diff --git a/CHANGELOG b/CHANGELOG index b3f0e241..8e3fc211 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,15 @@ +0.8 (2013-08-10) + * Share server connections between different clients + * Add tunnelAllowedPort option to limit ports CONNECT method can connect to + * Avoid timeout too soon for frequently visited direct sites + * Fix reporting malformed requests in two cases when request has body: + - Authenticate requests + - Error occured before request is sent + * Support multi-lined headers + * Change client connection timeout to 15s + * Change as direct delta to 15 + * Provide ARMv5 binary + 0.7.6 (2013-07-28) * Fix bug for close connection response with no body * Fix response not keep alive by default diff --git a/README.md b/README.md index 2e6a0416..4dd0ade4 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ COW 是一个利用二级代理自动化穿越防火墙的 HTTP 代理服务器。它能自动检测被墙网站,仅对这些网站使用二级代理。 -当前版本:0.7.6 [CHANGELOG](CHANGELOG) +当前版本:0.8 [CHANGELOG](CHANGELOG) [![Build Status](https://travis-ci.org/cyfdecyf/cow.png?branch=master)](https://travis-ci.org/cyfdecyf/cow) **欢迎在 develop branch 进行开发并发送 pull request :)** @@ -24,7 +24,7 @@ COW 的设计目标是自动化,理想情况下用户无需关心哪些网站 - **OS X, Linux:** 执行以下命令(也可用于更新) - curl -s -L https://github.com/cyfdecyf/cow/raw/master/install-cow.sh | bash + curl -L git.io/cow | bash - 该安装脚本在 OS X 上可将 COW 设置为登录时启动 - [Linux 启动脚本](doc/init.d/cow),如何使用请参考注释(Debian 测试通过,其他 Linux 发行版应该也可使用) @@ -106,13 +106,12 @@ COW 默认配置下检测到被墙后,过两分钟再次尝试直连也是为 贡献代码: -@tevino: http parent proxy basic authentication -@xupefei: 提供 cow-hide.exe 以在 windows 上在后台执行 cow.exe +- @tevino: http parent proxy basic authentication +- @xupefei: 提供 cow-hide.exe 以在 windows 上在后台执行 cow.exe Bug reporter: -GitHub users: glacjay, trawor, Blaskyy, lucifer9, zellux, xream, hieixu, fantasticfears, perrywky, JayXon, graminc, WingGao, polong, dallascao - -Twitter users: @shao222 +- GitHub users: glacjay, trawor, Blaskyy, lucifer9, zellux, xream, hieixu, fantasticfears, perrywky, JayXon, graminc, WingGao, polong, dallascao +- Twitter users: 特别感谢 @shao222 多次帮助测试新版并报告了不少 bug, @xixitalk @glacjay 对 0.3 版本的 COW 提出了让它更加自动化的建议,使我重新考虑 COW 的设计目标并且改进成 0.5 版本之后的工作方式。 diff --git a/auth.go b/auth.go index f24570da..97308827 100644 --- a/auth.go +++ b/auth.go @@ -187,14 +187,6 @@ func Authenticate(conn *clientConn, r *Request) (err error) { if authIP(clientIP) { // IP is allowed return } - /* - // No user specified - if auth.user == "" { - sendErrorPage(conn, "403 Forbidden", "Access forbidden", - "You are not allowed to use the proxy.") - return errShouldClose - } - */ err = authUserPasswd(conn, r) if err == nil { auth.authed.add(clientIP) @@ -227,43 +219,39 @@ func genNonce() string { func calcRequestDigest(kv map[string]string, ha1, method string) string { // Refer to rfc2617 section 3.2.2.1 Request-Digest - buf := bytes.NewBufferString(ha1) - buf.WriteByte(':') - buf.WriteString(kv["nonce"]) - buf.WriteByte(':') - buf.WriteString(kv["nc"]) - buf.WriteByte(':') - buf.WriteString(kv["cnonce"]) - buf.WriteByte(':') - buf.WriteString("auth") // qop value - buf.WriteByte(':') - buf.WriteString(md5sum(method + ":" + kv["uri"])) - - return md5sum(buf.String()) + arr := []string{ + ha1, + kv["nonce"], + kv["nc"], + kv["cnonce"], + "auth", + md5sum(method + ":" + kv["uri"]), + } + return md5sum(strings.Join(arr, ":")) } func checkProxyAuthorization(conn *clientConn, r *Request) error { - debug.Println("authorization:", r.ProxyAuthorization) + if debug { + debug.Printf("cli(%s) authorization: %s\n", conn.RemoteAddr(), r.ProxyAuthorization) + } + arr := 
strings.SplitN(r.ProxyAuthorization, " ", 2) if len(arr) != 2 { - errl.Println("auth: malformed ProxyAuthorization header:", r.ProxyAuthorization) - return errBadRequest + return errors.New("auth: malformed ProxyAuthorization header: " + r.ProxyAuthorization) } if strings.ToLower(strings.TrimSpace(arr[0])) != "digest" { - errl.Println("auth: client using unsupported authenticate method:", arr[0]) - return errBadRequest + return errors.New("auth: method " + arr[0] + " unsupported, must use digest") } authHeader := parseKeyValueList(arr[1]) if len(authHeader) == 0 { - errl.Println("auth: empty authorization list") - return errBadRequest + return errors.New("auth: empty authorization list") } nonceTime, err := strconv.ParseInt(authHeader["nonce"], 16, 64) if err != nil { - return err + return fmt.Errorf("auth: nonce %v", err) } // If nonce time too early, reject. iOS will create a new connection to do - // authenticate. + // authentication. if time.Now().Sub(time.Unix(nonceTime, 0)) > time.Minute { return errAuthRequired } @@ -271,7 +259,7 @@ func checkProxyAuthorization(conn *clientConn, r *Request) error { user := authHeader["username"] au, ok := auth.user[user] if !ok { - errl.Println("auth: no such user:", authHeader["username"]) + errl.Printf("cli(%s) auth: no such user: %s\n", conn.RemoteAddr(), authHeader["username"]) return errAuthRequired } @@ -280,22 +268,18 @@ func checkProxyAuthorization(conn *clientConn, r *Request) error { _, portStr, _ := net.SplitHostPort(conn.LocalAddr().String()) port, _ := strconv.Atoi(portStr) if uint16(port) != au.port { - errl.Println("auth: user", user, "port not match") + errl.Printf("cli(%s) auth: user %s port not match\n", conn.RemoteAddr(), user) return errAuthRequired } } if authHeader["qop"] != "auth" { - msg := "auth: qop wrong: " + authHeader["qop"] - errl.Println(msg) - return errors.New(msg) + return errors.New("auth: qop wrong: " + authHeader["qop"]) } response, ok := authHeader["response"] if !ok { - msg := "auth: no request-digest" - errl.Println(msg) - return errors.New(msg) + return errors.New("auth: no request-digest response") } au.initHA1(user) @@ -303,7 +287,7 @@ func checkProxyAuthorization(conn *clientConn, r *Request) error { if response == digest { return nil } - errl.Println("auth: digest not match, maybe password wrong") + errl.Printf("cli(%s) auth: digest not match, maybe password wrong", conn.RemoteAddr()) return errAuthRequired } @@ -328,15 +312,13 @@ func authUserPasswd(conn *clientConn, r *Request) (err error) { } buf := new(bytes.Buffer) if err := auth.template.Execute(buf, data); err != nil { - errl.Println("Error generating auth response:", err) - return errInternal + return fmt.Errorf("error generating auth response: %v", err) } - if debug { + if bool(debug) && verbose { debug.Printf("authorization response:\n%s", buf.String()) } if _, err := conn.Write(buf.Bytes()); err != nil { - errl.Println("Sending auth response error:", err) - return errShouldClose + return fmt.Errorf("send auth response error: %v", err) } return errAuthRequired } diff --git a/config.go b/config.go index 0a945ff6..59f7904f 100644 --- a/config.go +++ b/config.go @@ -4,7 +4,6 @@ import ( "flag" "fmt" "github.com/cyfdecyf/bufio" - "io" "net" "os" "path" @@ -15,7 +14,7 @@ import ( ) const ( - version = "0.7.6" + version = "0.8" defaultListenAddr = "127.0.0.1:7777" ) @@ -26,6 +25,17 @@ const ( loadBalanceHash ) +// allow the same tunnel ports as polipo +var defaultTunnelAllowedPort = []string{ + "22", "80", "443", // ssh, http, https + "873", // rsync 
+ "143", "220", "585", "993", // imap, imap3, imap4-ssl, imaps + "109", "110", "473", "995", // pop2, pop3, hybrid-pop, pop3s + "5222", "5269", // jabber-client, jabber-server + "2401", // cvspserver + "9418", // git +} + type Config struct { RcFile string // config file ListenAddr []string @@ -33,10 +43,9 @@ type Config struct { AlwaysProxy bool LoadBalance LoadBalanceMode - SshServer []string + TunnelAllowedPort map[string]bool // allowed ports to create tunnel - // http parent proxy - hasHttpParent bool + SshServer []string // authenticate client UserPasswd string @@ -54,6 +63,8 @@ type Config struct { // not configurable in config file PrintVer bool + + hasHttpParent bool // not config option } var config Config @@ -83,6 +94,11 @@ func init() { config.AuthTimeout = 2 * time.Hour config.DialTimeout = defaultDialTimeout config.ReadTimeout = defaultReadTimeout + + config.TunnelAllowedPort = make(map[string]bool) + for _, port := range defaultTunnelAllowedPort { + config.TunnelAllowedPort[port] = true + } } // Whether command line options specifies listen addr @@ -158,9 +174,6 @@ func (p configParser) ParseListen(val string) { return } arr := strings.Split(val, ",") - if config.ListenAddr == nil { - config.ListenAddr = make([]string, 0, len(arr)) - } for _, s := range arr { s = strings.TrimSpace(s) host, _, err := net.SplitHostPort(s) @@ -196,6 +209,17 @@ func (p configParser) ParseAddrInPAC(val string) { } } +func (p configParser) ParseTunnelAllowedPort(val string) { + arr := strings.Split(val, ",") + for _, s := range arr { + s = strings.TrimSpace(s) + if _, err := strconv.Atoi(s); err != nil { + Fatal("tunnel allowed ports", err) + } + config.TunnelAllowedPort[s] = true + } +} + // error checking is done in check config func (p configParser) ParseSocksParent(val string) { @@ -388,23 +412,15 @@ func parseConfig(path string) { IgnoreUTF8BOM(f) - fr := bufio.NewReader(f) + scanner := bufio.NewScanner(f) parser := reflect.ValueOf(configParser{}) zeroMethod := reflect.Value{} - var line string var n int - for { + for scanner.Scan() { n++ - line, err = ReadLine(fr) - if err == io.EOF { - return - } else if err != nil { - Fatalf("Error reading rc file: %v\n", err) - } - - line = strings.TrimSpace(line) + line := strings.TrimSpace(scanner.Text()) if line == "" || line[0] == '#' { continue } @@ -427,6 +443,9 @@ func parseConfig(path string) { args := []reflect.Value{reflect.ValueOf(val)} method.Call(args) } + if scanner.Err() != nil { + Fatalf("Error reading rc file: %v\n", scanner.Err()) + } } func updateConfig(nc *Config) { diff --git a/config_test.go b/config_test.go index 4df9e6b7..1e71dc40 100644 --- a/config_test.go +++ b/config_test.go @@ -17,3 +17,32 @@ func TestParseListen(t *testing.T) { t.Error("multiple listen address parse error") } } + +func TestTunnelAllowedPort(t *testing.T) { + parser := configParser{} + parser.ParseTunnelAllowedPort("1, 2, 3, 4, 5") + parser.ParseTunnelAllowedPort("6") + parser.ParseTunnelAllowedPort("7") + parser.ParseTunnelAllowedPort("8") + + testData := []struct { + port string + allowed bool + }{ + {"80", true}, // default allowd ports + {"443", true}, + {"1", true}, + {"3", true}, + {"5", true}, + {"7", true}, + {"8080", false}, + {"8388", false}, + } + + for _, td := range testData { + allowed := config.TunnelAllowedPort[td.port] + if allowed != td.allowed { + t.Errorf("port %s allowed %v, got %v\n", td.port, td.allowed, allowed) + } + } +} diff --git a/conn_pool.go b/conn_pool.go new file mode 100644 index 00000000..9af0701e --- /dev/null +++ 
b/conn_pool.go @@ -0,0 +1,135 @@ +// Shared server connections between different clients. + +package main + +import ( + "sync" + "time" +) + +// Maximum number of connections to a server. +const maxServerConnCnt = 20 + +// Store each server's connections in separate channels, getting +// connections for different servers can be done in parallel. +type ConnPool struct { + idleConn map[string]chan *serverConn + sync.RWMutex +} + +var connPool *ConnPool + +func initConnPool() { + connPool = new(ConnPool) + connPool.idleConn = make(map[string]chan *serverConn) +} + +func (cp *ConnPool) Get(hostPort string) *serverConn { + cp.RLock() + ch, ok := cp.idleConn[hostPort] + cp.RUnlock() + + if !ok { + return nil + } + + for { + select { + case sv := <-ch: + if sv.mayBeClosed() { + sv.Close() + continue + } + return sv + default: + return nil + } + } +} + +func (cp *ConnPool) Put(sv *serverConn) { + var ch chan *serverConn + + cp.RLock() + ch, ok := cp.idleConn[sv.hostPort] + cp.RUnlock() + + if !ok { + debug.Printf("connPool %s: new channel", sv.hostPort) + ch = make(chan *serverConn, maxServerConnCnt) + ch <- sv + cp.Lock() + cp.idleConn[sv.hostPort] = ch + cp.Unlock() + // start a new goroutine to close stale server connections + go closeStaleServerConn(ch, sv.hostPort) + return + } + + select { + case ch <- sv: + return + default: + // Simply close the connection if can't put into channel immediately. + // A better solution would remove old connections from the channel and + // add the new one. But's it's more complicated and this should happen + // rarely. + debug.Printf("connPool %s: channel full", sv.hostPort) + sv.Close() + } +} + +func closeStaleServerConn(ch chan *serverConn, hostPort string) { + // Tricky here. When removing a channel from the map, there maybe + // goroutines doing Put and Get using that channel. + + // For Get, there's no problem because it will return immediately. + // For Put, it's possible that a new connection is added to the + // channel, but the channel is no longer in the map. + // So after removed the channel from the map, we wait for several seconds + // and then close all connections left in it. + + // It's possible that Put add the connection after the final wait, but + // that should not happen in practice, and the worst result is just lost + // some memory and open fd. +done: + for { + time.Sleep(defaultServerConnTimeout) + cleanup: + for { + select { + case sv := <-ch: + if sv.mayBeClosed() { + debug.Printf("connPool channel %s: close one conn\n", hostPort) + sv.Close() + } else { + // Put it back and wait. + debug.Printf("connPool channel %s: put back conn\n", hostPort) + ch <- sv + break cleanup + } + default: + // no more connection in this channel + // remove the channel from the map + connPool.Lock() + delete(connPool.idleConn, hostPort) + connPool.Unlock() + debug.Printf("connPool channel %s: removed\n", hostPort) + break done + } + } + } + // Final wait and then close all left connections. In practice, there + // should be no other goroutines holding reference to the channel. 
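// (Illustrative interleaving of the race described above: Put looks up the
// channel under RLock and is descheduled, this goroutine deletes the map entry,
// then Put sends on the now-orphaned channel. The sleep below gives such a Put
// time to finish before the final drain, so the connection is closed rather
// than leaked.)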
+ time.Sleep(2 * time.Second) + for { + select { + case sv := <-ch: + debug.Printf("connPool channel %s: close conn after removed\n", hostPort) + sv.Close() + default: + debug.Printf("connPool channel %s: cleanup done\n", hostPort) + return + } + } +} diff --git a/conn_pool_test.go b/conn_pool_test.go new file mode 100644 index 00000000..b603750d --- /dev/null +++ b/conn_pool_test.go @@ -0,0 +1,60 @@ +package main + +import ( + "testing" + "time" +) + +func TestGetFromEmptyPool(t *testing.T) { + initConnPool() + + // should not block + sv := connPool.Get("foo") + if sv != nil { + t.Error("get non nil server conn from empty conn pool") + } +} + +func TestConnPool(t *testing.T) { + initConnPool() + + closeOn := time.Now().Add(10 * time.Second) + conns := []*serverConn{ + {hostPort: "example.com:80", willCloseOn: closeOn}, + {hostPort: "example.com:80", willCloseOn: closeOn}, + {hostPort: "example.com:80", willCloseOn: closeOn}, + {hostPort: "example.com:443", willCloseOn: closeOn}, + {hostPort: "google.com:443", willCloseOn: closeOn}, + {hostPort: "google.com:443", willCloseOn: closeOn}, + {hostPort: "www.google.com:80", willCloseOn: closeOn}, + } + for _, sv := range conns { + connPool.Put(sv) + } + + testData := []struct { + hostPort string + found bool + }{ + {"example.com", false}, + {"example.com:80", true}, + {"example.com:80", true}, + {"example.com:80", true}, + {"example.com:80", false}, // has 3 such conn + {"www.google.com:80", true}, + } + + for _, td := range testData { + sv := connPool.Get(td.hostPort) + if td.found { + if sv == nil { + t.Error("should find conn for", td.hostPort) + } else if sv.hostPort != td.hostPort { + t.Errorf("hostPort should be: %s, got: %s\n", td.hostPort, sv.hostPort) + } + } else if sv != nil { + t.Errorf("should NOT find conn for %s, got conn for: %s\n", + td.hostPort, sv.hostPort) + } + } +} diff --git a/doc/sample-config/rc b/doc/sample-config/rc index fa8fad0c..1343734e 100644 --- a/doc/sample-config/rc +++ b/doc/sample-config/rc @@ -94,6 +94,12 @@ listen = 127.0.0.1:7777 # 最多允许使用多少个 CPU 核 #core = 2 +# 允许建立隧道连接的端口,多个端口用逗号分隔,可重复多次 +# 默认允许下列服务的端口: ssh, http, https, rsync, imap, pop, jabber, cvs, git +# 如需允许其他端口,请用该选项添加 +# 限制隧道连接的端口可以防止将运行 COW 的服务器上只监听本机 ip 的服务暴露给外部 +#tunnelAllowedPort = 80, 443 + # GFW 会使 DNS 解析超时,也可能返回错误的地址,能连接但是读不到任何内容 # 下面两个值改小一点可以加速检测网站是否被墙,但网络情况差时可能误判 diff --git a/error.go b/error.go index f77d02be..69e2023b 100644 --- a/error.go +++ b/error.go @@ -14,7 +14,7 @@ var errPageRawTmpl = `

{{.H1}}

{{.Msg}}
- Generated by COW at {{.T}} + Generated by COW ` + version + ` at {{.T}} ` diff --git a/http.go b/http.go index 414d8714..a19a4531 100644 --- a/http.go +++ b/http.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "github.com/cyfdecyf/bufio" - "io" "net" "strconv" "strings" @@ -18,8 +17,9 @@ const ( const ( statusBadReq = "400 Bad Request" + statusForbidden = "403 Forbidden" statusExpectFailed = "417 Expectation Failed" - statusRequestTimeout = "408 Request Time-out" + statusRequestTimeout = "408 Request Timeout" ) type Header struct { @@ -27,6 +27,7 @@ type Header struct { KeepAlive time.Duration ProxyAuthorization string Chunking bool + Trailer bool ConnectionKeepAlive bool ExpectContinue bool } @@ -92,6 +93,13 @@ func (r *Request) Verbose() []byte { return rqbyte } +// Message body in request is signaled by the inclusion of a Content-Length +// or Transfer-Encoding header. +// Refer to http://stackoverflow.com/a/299696/306935 +func (r *Request) hasBody() bool { + return r.Chunking || r.ContLen > 0 +} + func (r *Request) isRetry() bool { return r.tryCnt > 1 } @@ -108,6 +116,10 @@ func (r *Request) responseNotSent() bool { return r.state <= rsSent } +func (r *Request) hasSent() bool { + return r.state >= rsSent +} + func (r *Request) releaseBuf() { if r.raw != nil { httpBuf.Put(r.rawByte) @@ -294,6 +306,7 @@ var headerParser = map[string]HeaderParserFunc{ headerProxyAuthorization: (*Header).parseProxyAuthorization, headerProxyConnection: (*Header).parseConnection, headerTransferEncoding: (*Header).parseTransferEncoding, + headerTrailer: (*Header).parseTrailer, headerExpect: (*Header).parseExpect, } @@ -316,7 +329,7 @@ var hopByHopHeader = map[string]bool{ // If Header needs to hold raw value, make a copy. For example, // parseProxyAuthorization does this. -type HeaderParserFunc func(*Header, []byte, *bytes.Buffer) error +type HeaderParserFunc func(*Header, []byte) error // Used by both "Connection" and "Proxy-Connection" header. COW always adds // connection header at the end of a request/response (in parseRequest and @@ -324,18 +337,18 @@ type HeaderParserFunc func(*Header, []byte, *bytes.Buffer) error // This will change the order of headers, but should be OK as RFC2616 4.2 says // header order is not significant. (Though general-header first is "good- // practice".) 
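// (Example of the rewrite, assuming the usual hop-by-hop list: a client line
// "Proxy-Connection: keep-alive" only sets ConnectionKeepAlive here and is not
// copied into raw; parseRequest later appends fullHeaderConnectionKeepAlive at
// the end of the rewritten headers.)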
-func (h *Header) parseConnection(s []byte, raw *bytes.Buffer) error { +func (h *Header) parseConnection(s []byte) error { ASCIIToLowerInplace(s) h.ConnectionKeepAlive = !bytes.Contains(s, []byte("close")) return nil } -func (h *Header) parseContentLength(s []byte, raw *bytes.Buffer) (err error) { +func (h *Header) parseContentLength(s []byte) (err error) { h.ContLen, err = ParseIntFromBytes(s, 10) return err } -func (h *Header) parseKeepAlive(s []byte, raw *bytes.Buffer) (err error) { +func (h *Header) parseKeepAlive(s []byte) (err error) { ASCIIToLowerInplace(s) id := bytes.Index(s, []byte("timeout=")) if id != -1 { @@ -343,37 +356,55 @@ func (h *Header) parseKeepAlive(s []byte, raw *bytes.Buffer) (err error) { end := id for ; end < len(s) && IsDigit(s[end]); end++ { } - delta, _ := ParseIntFromBytes(s[id:end], 10) + delta, err := ParseIntFromBytes(s[id:end], 10) + if err != nil { + return err // possible empty bytes + } h.KeepAlive = time.Second * time.Duration(delta) } return nil } -func (h *Header) parseProxyAuthorization(s []byte, raw *bytes.Buffer) error { +func (h *Header) parseProxyAuthorization(s []byte) error { h.ProxyAuthorization = string(s) return nil } -func (h *Header) parseTransferEncoding(s []byte, raw *bytes.Buffer) error { +func (h *Header) parseTransferEncoding(s []byte) error { ASCIIToLowerInplace(s) // For transfer-encoding: identify, it's the same as specifying neither // content-length nor transfer-encoding. h.Chunking = bytes.Contains(s, []byte("chunked")) - if h.Chunking { - raw.WriteString(fullHeaderTransferEncoding) - } else if !bytes.Contains(s, []byte("identity")) { - errl.Printf("invalid transfer encoding: %s\n", s) - return errNotSupported + if !h.Chunking && !bytes.Contains(s, []byte("identity")) { + return fmt.Errorf("invalid transfer encoding: %s", s) } return nil } -// For now, cow does not fully support 100-continue. It will return "417 +// RFC 2616 3.6.1 states when trailers are allowed: +// +// a) request includes TE header +// b) server is the original server +// +// Even though COW removes TE header, the original server can still respond +// with Trailer header. +// As Trailer is general header, it's possible to appear in request. But is +// there any client does this? +func (h *Header) parseTrailer(s []byte) error { + // use errl to test if this header is common to see + errl.Printf("got Trailer header: %s\n", s) + if len(s) != 0 { + h.Trailer = true + } + return nil +} + +// For now, COW does not fully support 100-continue. It will return "417 // expectation failed" if a request contains expect header. This is one of the // strategies supported by polipo, which is easiest to implement in cow. // TODO If we see lots of expect 100-continue usage, provide full support. 
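// (Hypothetical request hitting this path:
//
//	POST /upload HTTP/1.1
//	Host: example.com
//	Content-Length: 1024
//	Expect: 100-continue
//
// serve() replies with "417 Expectation Failed" instead of forwarding the
// request and waiting for a "100 Continue" from the server.)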
-func (h *Header) parseExpect(s []byte, raw *bytes.Buffer) error { +func (h *Header) parseExpect(s []byte) error { ASCIIToLowerInplace(s) errl.Printf("Expect header: %s\n", s) // put here to see if expect header is widely used h.ExpectContinue = true @@ -386,60 +417,115 @@ func (h *Header) parseExpect(s []byte, raw *bytes.Buffer) error { } func splitHeader(s []byte) (name, val []byte, err error) { - var f [][]byte - if f = bytes.SplitN(s, []byte{':'}, 2); len(f) != 2 { - errl.Printf("malformed header: %s\n", s) - return nil, nil, errMalformHeader + i := bytes.IndexByte(s, ':') + if i < 0 { + return nil, nil, fmt.Errorf("malformed header: %#v", string(s)) } // Do not lower case field value, as it maybe case sensitive - return ASCIIToLower(f[0]), f[1], nil + return ASCIIToLower(s[:i]), TrimSpace(s[i+1:]), nil +} + +// Learned from net.textproto. One difference is that this one keeps the +// ending '\n' in the returned line. Buf if there's only CRLF in the line, +// return nil for the line. +func readContinuedLineSlice(r *bufio.Reader) ([]byte, error) { + // feedly.com request headers contains things like: + // "$Authorization.feedly: $FeedlyAuth\r\n", so we must test for only + // continuation spaces. + isspace := func(b byte) bool { + return b == ' ' || b == '\t' + } + + // Read the first line. + line, err := r.ReadSlice('\n') + if err != nil { + return nil, err + } + + // There are servers that use \n for line ending, so trim first before check ending. + // For example, the 404 page for http://plan9.bell-labs.com/magic/man2html/1/2l + trimmed := TrimSpace(line) + if len(trimmed) == 0 { + if len(line) > 2 { + return nil, fmt.Errorf("malformed end of headers, len: %d, %#v", len(line), string(line)) + } + return nil, nil + } + + if isspace(line[0]) { + return nil, fmt.Errorf("malformed header, start with space: %#v", string(line)) + } + + // Optimistically assume that we have started to buffer the next line + // and it starts with an ASCII letter (the next header key), so we can + // avoid copying that buffered data around in memory and skipping over + // non-existent whitespace. + if r.Buffered() > 0 { + peek, err := r.Peek(1) + if err == nil && !isspace(peek[0]) { + return line, nil + } + } + + var buf []byte + buf = append(buf, trimmed...) + + // Read continuation lines. + for skipSpace(r) > 0 { + line, err := r.ReadSlice('\n') + if err != nil { + break + } + buf = append(buf, ' ') + buf = append(buf, TrimTrailingSpace(line)...) + } + buf = append(buf, '\r', '\n') + return buf, nil +} + +func skipSpace(r *bufio.Reader) int { + n := 0 + for { + c, err := r.ReadByte() + if err != nil { + // Bufio will keep err until next read. + break + } + if c != ' ' && c != '\t' { + r.UnreadByte() + break + } + n++ + } + return n } // Only add headers that are of interest for a proxy into request/response's header map. func (h *Header) parseHeader(reader *bufio.Reader, raw *bytes.Buffer, url *URL) (err error) { h.ContLen = -1 - dummyLastLine := []byte{} - // Read request header and body - var s, name, val, lastLine []byte for { - if s, err = reader.ReadSlice('\n'); err != nil { + var line, name, val []byte + if line, err = readContinuedLineSlice(reader); err != nil || len(line) == 0 { return } - // There are servers that use \n for line ending, so trim first before check ending. 
- // For example, the 404 page for http://plan9.bell-labs.com/magic/man2html/1/2l - trimmed := TrimSpace(s) - if len(trimmed) == 0 { // end of headers - return - } - if (s[0] == ' ' || s[0] == '\t') && lastLine != nil { // multi-line header - // I've never seen multi-line header used in headers that's of interest. - // Disable multi-line support to avoid copy for now. - errl.Printf("Multi-line support disabled: %v %s", url, s) - return errNotSupported - // combine previous line with current line - // trimmed = bytes.Join([][]byte{lastLine, []byte{' '}, trimmed}, nil) - } - if name, val, err = splitHeader(trimmed); err != nil { + if name, val, err = splitHeader(line); err != nil { + errl.Printf("%v raw header:\n%s\n", err, raw.Bytes()) return } // Wait Go to solve/provide the string<->[]byte optimization kn := string(name) if parseFunc, ok := headerParser[kn]; ok { - // lastLine = append([]byte(nil), trimmed...) // copy to avoid next read invalidating the trimmed line - lastLine = dummyLastLine - val = TrimSpace(val) if len(val) == 0 { continue } - parseFunc(h, val, raw) - } else { - // mark this header as not of interest to proxy - lastLine = nil + if err = parseFunc(h, val); err != nil { + return + } } if hopByHopHeader[kn] { continue } - raw.Write(s) + raw.Write(line) // debug.Printf("len %d %s", len(s), s) } } @@ -449,8 +535,7 @@ func parseRequest(c *clientConn, r *Request) (err error) { var s []byte reader := c.bufRd // make actual timeout a little longer than keep-alive value sent to client - setConnReadTimeout(c.Conn, - clientConnTimeout+time.Duration(c.timeoutCnt)*time.Second, "parseRequest") + setConnReadTimeout(c.Conn, clientConnTimeout+2*time.Second, "parseRequest") // parse request line if s, err = reader.ReadSlice('\n'); err != nil { if isErrTimeout(err) { @@ -471,7 +556,7 @@ func parseRequest(c *clientConn, r *Request) (err error) { var f [][]byte // Tolerate with multiple spaces and '\t' is achieved by FieldsN. if f = FieldsN(s, 3); len(f) != 3 { - return errors.New(fmt.Sprintf("malformed HTTP request: %s", s)) + return fmt.Errorf("malformed request line: %#v", string(s)) } ASCIIToUpperInplace(f[0]) r.Method = string(f[0]) @@ -498,11 +583,14 @@ func parseRequest(c *clientConn, r *Request) (err error) { } r.headStart = r.raw.Len() - // Read request header + // Read request header. 
if err = r.parseHeader(reader, r.raw, r.URL); err != nil { - errl.Printf("Parsing request header: %v %v\n", err, r) + errl.Printf("parse request header: %v %s\n%s", err, r, r.Verbose()) return err } + if r.Chunking { + r.raw.WriteString(fullHeaderTransferEncoding) + } if r.ConnectionKeepAlive { r.raw.WriteString(fullHeaderConnectionKeepAlive) } else { @@ -516,19 +604,10 @@ func parseRequest(c *clientConn, r *Request) (err error) { return } -func skipCRLF(r *bufio.Reader) error { - // There maybe servers using single '\n' for line ending - if _, err := r.ReadSlice('\n'); err != nil { - errl.Println("Error reading CRLF:", err) - return err - } - return nil -} - // If an http response may have message body func (rp *Response) hasBody(method string) bool { if method == "HEAD" || rp.Status == 304 || rp.Status == 204 || - (100 <= rp.Status && rp.Status < 200) { + rp.Status < 200 { return false } return true @@ -542,11 +621,13 @@ func parseResponse(sv *serverConn, r *Request, rp *Response) (err error) { sv.setReadTimeout("parseResponse") } if s, err = reader.ReadSlice('\n'); err != nil { - if err != io.EOF { - // err maybe timeout caused by explicity setting deadline - debug.Printf("Reading Response status line: %v %v\n", err, r) - } - // For timeout, the connection will not be used, so no need to unset timeout + // err maybe timeout caused by explicity setting deadline, EOF, or + // reset caused by GFW. + debug.Printf("read response status line %v %v\n", err, r) + // Server connection with error will not be used any more, so no need + // to unset timeout. + // For read error, return directly in order to identify whether this + // is caused by GFW. return err } if sv.maybeFake() { @@ -557,16 +638,14 @@ func parseResponse(sv *serverConn, r *Request, rp *Response) (err error) { // response status line parsing var f [][]byte if f = FieldsN(s, 3); len(f) < 2 { // status line are separated by SP - errl.Printf("Malformed HTTP response status line: %s %v\n", s, r) - return errMalformResponse + return fmt.Errorf("malformed response status line: %#v %v", string(s), r) } status, err := ParseIntFromBytes(f[1], 10) rp.reset() rp.Status = int(status) if err != nil { - errl.Printf("response status not valid: %s len=%d %v\n", f[1], len(f[1]), err) - return + return fmt.Errorf("response status not valid: %s len=%d %v", f[1], len(f[1]), err) } if len(f) == 3 { rp.Reason = f[2] @@ -574,8 +653,7 @@ func parseResponse(sv *serverConn, r *Request, rp *Response) (err error) { proto := f[0] if !bytes.Equal(proto[0:7], []byte("HTTP/1.")) { - errl.Printf("Invalid response status line: %s request %v\n", string(f[0]), r) - return errMalformResponse + return fmt.Errorf("invalid response status line: %s request %v", string(f[0]), r) } if proto[7] == '1' { rp.raw.Write(s) @@ -584,12 +662,11 @@ func parseResponse(sv *serverConn, r *Request, rp *Response) (err error) { // will be converted to chunked encoding rp.raw.WriteString(rp.genStatusLine()) } else { - errl.Printf("Response protocol not supported: %s\n", f[0]) - return errNotSupported + return fmt.Errorf("response protocol not supported: %s", f[0]) } if err = rp.parseHeader(reader, rp.raw, r.URL); err != nil { - errl.Printf("Reading response header: %v %v\n", err, r) + errl.Printf("parse response header: %v %s\n%s", err, rp, rp.Verbose()) return err } @@ -599,23 +676,21 @@ func parseResponse(sv *serverConn, r *Request, rp *Response) (err error) { return parseResponse(sv, r, rp) } - // Connection close, no content length specification - // Use chunked encoding to pass 
content back to client - if !rp.ConnectionKeepAlive && !rp.Chunking && rp.ContLen == -1 { + if rp.Chunking { + rp.raw.WriteString(fullHeaderTransferEncoding) + } else if rp.ContLen == -1 { + // No chunk, no content length, assume close to signal end. + rp.ConnectionKeepAlive = false if rp.hasBody(r.Method) { + // Connection close, no content length specification. + // Use chunked encoding to pass content back to client. debug.Println("add chunked encoding to close connection response", r, rp) - rp.raw.WriteString("Transfer-Encoding: chunked\r\n") + rp.raw.WriteString(fullHeaderTransferEncoding) } else { debug.Println("add content-length 0 to close connection response", r, rp) rp.raw.WriteString("Content-Length: 0\r\n") } } - // Check for invalid response - if !rp.hasBody(r.Method) && (rp.Chunking || rp.ContLen != -1) { - errl.Printf("response has no body, but with chunked/content-length set\n%s", - rp.Verbose()) - } - // Whether COW should respond with keep-alive depends on client request, // not server response. if r.ConnectionKeepAlive { diff --git a/http_test.go b/http_test.go index b560a2b7..330b9d1d 100644 --- a/http_test.go +++ b/http_test.go @@ -71,18 +71,16 @@ func TestParseHeader(t *testing.T) { "content-length: 64\r\n", &Header{ContLen: 64, Chunking: false, ConnectionKeepAlive: true}}, {"Connection: keep-alive\r\nKeep-Alive: timeout=10\r\nTransfer-Encoding: chunked\r\nTE: trailers\r\n\r\n", - "Transfer-Encoding: chunked\r\n", + "", &Header{ContLen: -1, Chunking: true, ConnectionKeepAlive: true, KeepAlive: 10 * time.Second}}, - /* - {"Connection: keep-alive\r\nKeep-Alive: max=5,\r\n timeout=10\r\n\r\n", // test multi-line header - "Connection: keep-alive\r\n", - &Header{ContLen: -1, Chunking: false, ConnectionKeepAlive: true, - KeepAlive: 10 * time.Second}}, - {"Connection: \r\n keep-alive\r\n\r\n", // test multi-line header - "Connection: keep-alive\r\n", - &Header{ContLen: -1, Chunking: false, ConnectionKeepAlive: true}}, - */ + {"Connection:\r\n keep-alive\r\nKeep-Alive: max=5,\r\n timeout=10\r\n\r\n", + "", + &Header{ContLen: -1, Chunking: false, ConnectionKeepAlive: true, + KeepAlive: 10 * time.Second}}, + {"Connection: \r\n close\r\nLong: line\r\n continued\r\n\tagain\r\n\r\n", + "Long: line continued again\r\n", + &Header{ContLen: -1, Chunking: false, ConnectionKeepAlive: false}}, } for _, td := range testData { var h Header diff --git a/install-cow.sh b/install-cow.sh index 0a4998df..aa8c021c 100755 --- a/install-cow.sh +++ b/install-cow.sh @@ -1,6 +1,6 @@ #!/bin/bash -version=0.7.6 +version=0.8 arch=`uname -m` case $arch in @@ -13,6 +13,9 @@ case $arch in "armv6l") arch="-armv6" ;; + "armv5tel") + arch="-armv5" + ;; *) echo "$arch currently has no precompiled binary" ;; diff --git a/main.go b/main.go index 2d9f6df4..0752e593 100644 --- a/main.go +++ b/main.go @@ -49,6 +49,7 @@ func main() { initAuth() initSiteStat() initPAC() // initPAC uses siteStat, so must init after site stat + initConnPool() initStat() diff --git a/parent_proxy.go b/parent_proxy.go index ee8a1596..0c16e985 100644 --- a/parent_proxy.go +++ b/parent_proxy.go @@ -11,7 +11,7 @@ import ( "strconv" ) -func connectByParentProxy(url *URL) (srvconn conn, err error) { +func connectByParentProxy(url *URL) (srvconn net.Conn, err error) { const baseFailCnt = 9 var skipped []int nproxy := len(parentProxy) @@ -43,12 +43,12 @@ func connectByParentProxy(url *URL) (srvconn conn, err error) { if len(parentProxy) != 0 { return } - return zeroConn, errNoParentProxy + return nil, errors.New("no parent proxy") } // 
proxyConnector is the interface that all parent proxies should support. type proxyConnector interface { - connect(*URL) (conn, error) + connect(*URL) (net.Conn, error) } type ParentProxy struct { @@ -62,7 +62,7 @@ func addParentProxy(pc proxyConnector) { parentProxy = append(parentProxy, ParentProxy{pc, 0}) } -func (pp *ParentProxy) connect(url *URL) (srvconn conn, err error) { +func (pp *ParentProxy) connect(url *URL) (srvconn net.Conn, err error) { const maxFailCnt = 30 srvconn, err = pp.proxyConnector.connect(url) if err != nil { @@ -95,6 +95,15 @@ type httpParent struct { authHeader []byte } +type httpConn struct { + net.Conn + parent *httpParent +} + +func (hc httpConn) String() string { + return "http parent proxy " + hc.parent.server +} + func newHttpParent(server string) *httpParent { return &httpParent{server: server} } @@ -104,14 +113,14 @@ func (hp *httpParent) initAuth(userPasswd string) { hp.authHeader = []byte(headerProxyAuthorization + ": Basic " + b64 + CRLF) } -func (hp *httpParent) connect(url *URL) (cn conn, err error) { +func (hp *httpParent) connect(url *URL) (net.Conn, error) { c, err := net.Dial("tcp", hp.server) if err != nil { errl.Printf("can't connect to http parent proxy %s for %s: %v\n", hp.server, url.HostPort, err) - return zeroConn, err + return nil, err } debug.Println("connected to:", url.HostPort, "via http parent proxy:", hp.server) - return conn{ctHttpProxyConn, c, hp}, nil + return httpConn{c, hp}, nil } // shadowsocks parent proxy @@ -120,6 +129,15 @@ type shadowsocksParent struct { cipher *ss.Cipher } +type shadowsocksConn struct { + net.Conn + parent *shadowsocksParent +} + +func (s shadowsocksConn) String() string { + return "shadowsocks proxy " + s.parent.server +} + // In order to use parent proxy in the order specified in the config file, we // insert an uninitialized proxy into parent proxy list, and initialize it // when all its config have been parsed. 
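// A minimal sketch (not part of this change) of how the typed connections are
// meant to be used: each parent proxy satisfies proxyConnector, and because
// httpConn, shadowsocksConn, socksConn and directConn all implement
// fmt.Stringer, an error message can simply format the net.Conn value.
//
//	var c net.Conn = httpConn{parent: newHttpParent("10.0.0.1:8080")} // made-up address
//	fmt.Printf("Using %s.\n", c) // prints: Using http parent proxy 10.0.0.1:8080.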
@@ -136,15 +154,15 @@ func (sp *shadowsocksParent) initCipher(passwd, method string) { sp.cipher = cipher } -func (sp *shadowsocksParent) connect(url *URL) (conn, error) { +func (sp *shadowsocksParent) connect(url *URL) (net.Conn, error) { c, err := ss.Dial(url.HostPort, sp.server, sp.cipher.Copy()) if err != nil { errl.Printf("create shadowsocks connection to %s through server %s failed %v\n", url.HostPort, sp.server, err) - return zeroConn, err + return nil, err } debug.Println("connected to:", url.HostPort, "via shadowsocks:", sp.server) - return conn{ctShadowctSocksConn, c, sp}, nil + return shadowsocksConn{c, sp}, nil } // For socks documentation, refer to rfc 1928 http://www.ietf.org/rfc/rfc1928.txt @@ -174,16 +192,25 @@ type socksParent struct { server string } +type socksConn struct { + net.Conn + parent socksParent +} + +func (s socksConn) String() string { + return "socks proxy " + s.parent.server +} + func newSocksParent(server string) socksParent { return socksParent{server} } -func (sp socksParent) connect(url *URL) (cn conn, err error) { +func (sp socksParent) connect(url *URL) (net.Conn, error) { c, err := net.Dial("tcp", sp.server) if err != nil { errl.Printf("can't connect to socks server %s for %s: %v\n", sp.server, url.HostPort, err) - return + return nil, err } hasErr := false defer func() { @@ -196,7 +223,7 @@ func (sp socksParent) connect(url *URL) (cn conn, err error) { if n, err = c.Write(socksMsgVerMethodSelection); n != 3 || err != nil { errl.Printf("sending ver/method selection msg %v n = %v\n", err, n) hasErr = true - return + return nil, err } // version/method selection @@ -205,13 +232,13 @@ func (sp socksParent) connect(url *URL) (cn conn, err error) { if err != nil { errl.Printf("read ver/method selection error %v\n", err) hasErr = true - return + return nil, err } if repBuf[0] != 5 || repBuf[1] != 0 { errl.Printf("socks ver/method selection reply error ver %d method %d", repBuf[0], repBuf[1]) hasErr = true - return + return nil, err } // debug.Println("Socks version selection done") @@ -221,7 +248,7 @@ func (sp socksParent) connect(url *URL) (cn conn, err error) { if err != nil { errl.Printf("should not happen, port error %v\n", port) hasErr = true - return + return nil, err } hostLen := len(host) @@ -244,7 +271,7 @@ func (sp socksParent) connect(url *URL) (cn conn, err error) { if n, err = c.Write(reqBuf); err != nil || n != bufLen { errl.Printf("send socks request err %v n %d\n", err, n) hasErr = true - return + return nil, err } // I'm not clear why the buffer is fixed at 10. The rfc document does not say this. @@ -256,27 +283,27 @@ func (sp socksParent) connect(url *URL) (cn conn, err error) { errl.Printf("read socks reply err %v n %d\n", err, n) } hasErr = true - return zeroConn, errors.New("Connection failed (by socks server). No such host?") + return nil, errors.New("Connection failed (by socks server). 
No such host?") } // debug.Printf("Socks reply length %d\n", n) if replyBuf[0] != 5 { errl.Printf("socks reply connect %s VER %d not supported\n", url.HostPort, replyBuf[0]) hasErr = true - return zeroConn, socksProtocolErr + return nil, socksProtocolErr } if replyBuf[1] != 0 { errl.Printf("socks reply connect %s error %s\n", url.HostPort, socksError[replyBuf[1]]) hasErr = true - return zeroConn, socksProtocolErr + return nil, socksProtocolErr } if replyBuf[3] != 1 { errl.Printf("socks reply connect %s ATYP %d\n", url.HostPort, replyBuf[3]) hasErr = true - return zeroConn, socksProtocolErr + return nil, socksProtocolErr } debug.Println("connected to:", url.HostPort, "via socks server:", sp.server) // Now the socket can be used to pass data. - return conn{ctSocksConn, c, sp}, nil + return socksConn{c, sp}, nil } diff --git a/proxy.go b/proxy.go index ee81750c..10e83c2f 100644 --- a/proxy.go +++ b/proxy.go @@ -8,7 +8,6 @@ import ( "github.com/cyfdecyf/leakybuf" "io" "net" - "reflect" "strings" "time" ) @@ -19,10 +18,12 @@ import ( // // For limits about URL and HTTP header size, refer to: // http://stackoverflow.com/questions/417142/what-is-the-maximum-length-of-a-url -// (URL usually are less than 2100 bytes.) +// "de facto limit of 2000 characters" // http://www.mnot.net/blog/2011/07/11/what_proxies_must_do -// (This says "URIs should be allowed at least 8000 octets, and HTTP headers -// (should have 4000 as an absolute minimum".) +// "URIs should be allowed at least 8000 octets, and HTTP headers should have +// 4000 as an absolute minimum". +// In practice, there are sites using cookies larger than 4096 bytes, +// e.g. www.fitbit.com. So set http buffer size to 8192 to be safe. const httpBufSize = 8192 // Hold at most 4MB memory as buffer for parsing http request/response and @@ -32,18 +33,12 @@ var httpBuf = leakybuf.NewLeakyBuf(512, httpBufSize) // If no keep-alive header in response, use this as the keep-alive value. const defaultServerConnTimeout = 15 * time.Second -// Close client connection if no new request received in some time. To prevent -// keeping too many idle (keep-alive) server connections, COW timeout read for -// the initial request line after clientConnTimeout, and closes client -// connection after clientMaxTimeoutCnt timeout. +// Close client connection if no new requests received in some time. // (On OS X, the default soft limit of open file descriptor is 256, which is -// very conservative and easy to cause problem if we are not careful.) -const clientConnTimeout = 5 * time.Second -const clientMaxTimeoutCnt = 2 -const fullKeepAliveHeader = "Keep-Alive: timeout=10\r\n" - -// Remove idle server connection every cleanServerInterval second. -const cleanServerInterval = 5 * time.Second +// very conservative and easy to cause problem if we are not careful to limit +// open fds.) +const clientConnTimeout = 15 * time.Second +const fullKeepAliveHeader = "Keep-Alive: timeout=15\r\n" // If client closed connection for HTTP CONNECT method in less then 1 second, // consider it as an ssl error. 
This is only effective for Chrome which will @@ -71,22 +66,14 @@ type Proxy struct { addrInPAC string // proxy server address to use in PAC } -type connType byte +var zeroTime time.Time -const ( - ctNilConn connType = iota - ctDirectConn - ctSocksConn - ctShadowctSocksConn - ctHttpProxyConn -) +type directConn struct { + net.Conn +} -var ctName = [...]string{ - ctNilConn: "nil", - ctDirectConn: "direct", - ctSocksConn: "socks5", - ctShadowctSocksConn: "shadowsocks", - ctHttpProxyConn: "http parent", +func (dc directConn) String() string { + return "direct connection" } type serverConnState byte @@ -97,20 +84,11 @@ const ( svStopped ) -type conn struct { - connType - net.Conn - creator proxyConnector -} - -var zeroConn conn -var zeroTime time.Time - type serverConn struct { - conn + net.Conn bufRd *bufio.Reader buf []byte // buffer for the buffered reader - url *URL + hostPort string state serverConnState willCloseOn time.Time siteInfo *VisitCnt @@ -118,29 +96,16 @@ type serverConn struct { } type clientConn struct { - net.Conn // connection to the proxy client - bufRd *bufio.Reader - buf []byte // buffer for the buffered reader - serverConn map[string]*serverConn // request serverConn, host:port as key - timeoutCnt int // number of timeouts reading requests - cleanedOn time.Time // time of last idle server clean up - proxy *Proxy + net.Conn // connection to the proxy client + bufRd *bufio.Reader + buf []byte // buffer for the buffered reader + proxy *Proxy } var ( - errTooManyRetry = errors.New("Too many retry") - errPageSent = errors.New("Error page has sent") - errShouldClose = errors.New("Error can only be handled by close connection") - errInternal = errors.New("Internal error") - errNoParentProxy = errors.New("No parent proxy") - errClientTimeout = errors.New("Read client request timeout") - - errChunkedEncode = errors.New("Invalid chunked encoding") - errMalformHeader = errors.New("Malformed HTTP header") - errMalformResponse = errors.New("Malformed HTTP response") - errNotSupported = errors.New("Not supported") - errBadRequest = errors.New("Bad request") - errAuthRequired = errors.New("Authentication requried") + errPageSent = errors.New("error page has sent") + errClientTimeout = errors.New("read client request timeout") + errAuthRequired = errors.New("authentication requried") ) func NewProxy(addr, addrInPAC string) *Proxy { @@ -162,17 +127,17 @@ func (py *Proxy) Serve(done chan byte) { } host, _, _ := net.SplitHostPort(py.addr) if host == "" || host == "0.0.0.0" { - info.Printf("COW proxy address %s, PAC url http://:%s/pac\n", py.addr, py.port) + info.Printf("COW %s proxy address %s, PAC url http://:%s/pac\n", version, py.addr, py.port) } else if py.addrInPAC == "" { - info.Printf("COW proxy address %s, PAC url http://%s/pac\n", py.addr, py.addr) + info.Printf("COW %s proxy address %s, PAC url http://%s/pac\n", version, py.addr, py.addr) } else { - info.Printf("COW proxy address %s, PAC url http://%s/pac\n", py.addr, py.addrInPAC) + info.Printf("COW %s proxy address %s, PAC url http://%s/pac\n", version, py.addr, py.addrInPAC) } for { conn, err := ln.Accept() if err != nil { - debug.Println("client connection:", err) + debug.Printf("proxy(%s) accept %v\n", ln.Addr(), err) continue } c := newClientConn(conn, py) @@ -183,11 +148,10 @@ func (py *Proxy) Serve(done chan byte) { func newClientConn(cli net.Conn, proxy *Proxy) *clientConn { buf := httpBuf.Get() c := &clientConn{ - Conn: cli, - serverConn: map[string]*serverConn{}, - buf: buf, - bufRd: bufio.NewReaderFromBuf(cli, buf), - 
proxy: proxy, + Conn: cli, + buf: buf, + bufRd: bufio.NewReaderFromBuf(cli, buf), + proxy: proxy, } if debug { debug.Printf("cli(%s) connected, total %d clients\n", @@ -197,19 +161,16 @@ func newClientConn(cli net.Conn, proxy *Proxy) *clientConn { } func (c *clientConn) releaseBuf() { - c.bufRd = nil - if c.buf != nil { + if c.bufRd != nil { // debug.Println("release client buffer") httpBuf.Put(c.buf) c.buf = nil + c.bufRd = nil } } func (c *clientConn) Close() { c.releaseBuf() - for _, sv := range c.serverConn { - sv.Close(c) - } if debug { debug.Printf("cli(%s) closed, total %d clients\n", c.RemoteAddr(), decCliCnt()) @@ -231,7 +192,7 @@ func (c *clientConn) serveSelfURL(r *Request) (err error) { return errPageSent } end: - sendErrorPage(c, "404 not found", "Page not found", "Handling request to proxy itself.") + sendErrorPage(c, "404 not found", "Page not found", "Serving request to COW proxy.") return errPageSent } @@ -241,11 +202,15 @@ func (c *clientConn) shouldRetry(r *Request, sv *serverConn, re error) bool { } err, _ := re.(RetryError) if !r.responseNotSent() { - debug.Printf("%v has sent some response, can't retry\n", r) + if debug { + debug.Printf("cli(%s) has sent some response, can't retry %v\n", c.RemoteAddr(), r) + } return false } if r.partial { - debug.Printf("%v partial request, can't retry\n", r) + if debug { + debug.Printf("cli(%s) partial request, can't retry %v\n", c.RemoteAddr(), r) + } sendErrorPage(c, "502 partial request", err.Error(), genErrMsg(r, sv, "Request is too large to hold in buffer, can't retry. "+ "Refresh to retry may work.")) @@ -263,7 +228,7 @@ func (c *clientConn) shouldRetry(r *Request, sv *serverConn, re error) bool { r.tryCnt = 0 return true } - debug.Printf("Can't retry %v tryCnt=%d\n", r, r.tryCnt) + debug.Printf("cli(%s) can't retry %v tryCnt=%d\n", c.RemoteAddr(), r, r.tryCnt) sendErrorPage(c, "502 retry failed", "Can't finish HTTP request", genErrMsg(r, sv, "Has tried several times.")) return false @@ -272,6 +237,10 @@ func (c *clientConn) shouldRetry(r *Request, sv *serverConn, re error) bool { } func dbgPrintRq(c *clientConn, r *Request) { + if r.Trailer { + errl.Printf("cli(%s) request %s has Trailer header\n%s", + c.RemoteAddr(), r, r.Verbose()) + } if dbgRq { if verbose { dbgRq.Printf("cli(%s) request %s\n%s", c.RemoteAddr(), r, r.Verbose()) @@ -281,6 +250,12 @@ func dbgPrintRq(c *clientConn, r *Request) { } } +type SinkWriter struct{} + +func (s SinkWriter) Write(p []byte) (int, error) { + return len(p), nil +} + func (c *clientConn) serve() { var r Request var rp Response @@ -288,7 +263,6 @@ func (c *clientConn) serve() { var err error var authed bool - var authCnt int defer func() { r.releaseBuf() @@ -301,10 +275,6 @@ func (c *clientConn) serve() { if c.bufRd == nil || c.buf == nil { panic("client read buffer nil") } - // clean up idle server connection before waiting for client request - if c.shouldCleanServerConn() { - c.cleanServerConn() - } if err = parseRequest(c, &r); err != nil { if debug { @@ -317,19 +287,15 @@ func (c *clientConn) serve() { sendErrorPage(c, "404 Bad request", "Bad request", err.Error()) return } - c.timeoutCnt++ - if c.timeoutCnt < clientMaxTimeoutCnt { - c.cleanServerConn() - continue - } sendErrorPage(c, statusRequestTimeout, statusRequestTimeout, "Your browser didn't send a complete request in time.") return } - // next getRequest should start with timeout count 0 - c.timeoutCnt = 0 dbgPrintRq(c, &r) + // PAC may leak frequently visited sites information. 
But if cow + // requires authentication for PAC, some clients may not be able + // handle it. (e.g. Proxy SwitchySharp extension on Chrome.) if isSelfURL(r.URL.HostPort) { if err = c.serveSelfURL(&r); err != nil { return @@ -338,20 +304,22 @@ func (c *clientConn) serve() { } if auth.required && !authed { - if authCnt > 5 { - return - } if err = Authenticate(c, &r); err != nil { - if err == errAuthRequired { - authCnt++ - continue - } else { - return - } + errl.Printf("cli(%s) %v\n", c.RemoteAddr(), err) + // Request may have body. To make things simple, close + // connection so we don't need to skip request body before + // reading the next request. + return } authed = true } + if r.isConnect && !config.TunnelAllowedPort[r.URL.Port] { + sendErrorPage(c, statusForbidden, "Forbidden tunnel port", + genErrMsg(&r, nil, "Please contact proxy admin.")) + return + } + if r.ExpectContinue { sendErrorPage(c, statusExpectFailed, "Expect header not supported", "Please contact COW's developer if you see this.") @@ -365,24 +333,30 @@ func (c *clientConn) serve() { retry: r.tryOnce() if bool(debug) && r.isRetry() { - errl.Printf("cli(%s) retry request tryCnt=%d %v\n", c.RemoteAddr(), r.tryCnt, &r) + debug.Printf("cli(%s) retry request tryCnt=%d %v\n", c.RemoteAddr(), r.tryCnt, &r) } if sv, err = c.getServerConn(&r); err != nil { - // debug.Printf("Failed to get serverConn for %s %v\n", c.RemoteAddr(), r) + if debug { + debug.Printf("cli(%s) failed to get server conn %v\n", c.RemoteAddr(), &r) + } // Failed connection will send error page back to the client. // For CONNECT, the client read buffer is released in copyClient2Server, // so can't go back to getRequest. if err == errPageSent && !r.isConnect { + if r.hasBody() { + // skip request body + debug.Printf("cli(%s) skip request body %v\n", c.RemoteAddr(), &r) + sendBody(SinkWriter{}, c.bufRd, int(r.ContLen), r.Chunking) + } continue } return } if r.isConnect { + // server connection will be closed in doConnect err = sv.doConnect(&r, c) - sv.Close(c) if c.shouldRetry(&r, sv, err) { - // connection for CONNECT is not reused, no need to remove goto retry } // debug.Printf("doConnect %s to %s done\n", c.RemoteAddr(), r.URL.HostPort) @@ -390,18 +364,37 @@ func (c *clientConn) serve() { } if err = sv.doRequest(c, &r, &rp); err != nil { - c.removeServerConn(sv) + // For client I/O error, we can actually put server connection to + // pool. But let's make thing simple for now. + sv.Close() if c.shouldRetry(&r, sv, err) { goto retry - } else if err == errPageSent { + } else if err == errPageSent && (!r.hasBody() || r.hasSent()) { + // Can only continue if request has no body, or request body + // has been read. continue } return } + // Put server connection to pool, so other clients can use it. + if rp.ConnectionKeepAlive { + if debug { + debug.Printf("cli(%s) connPool put %s", c.RemoteAddr(), sv.hostPort) + } + // If the server connection is not going to be used soon, + // release buffer before putting back to pool can save memory. + sv.releaseBuf() + connPool.Put(sv) + } else { + if debug { + debug.Printf("cli(%s) server %s close conn\n", c.RemoteAddr(), sv.hostPort) + } + sv.Close() + } if !r.ConnectionKeepAlive { if debug { - debug.Println("cli(%s) close connection", c.RemoteAddr()) + debug.Printf("cli(%s) close connection\n", c.RemoteAddr()) } return } @@ -412,8 +405,8 @@ func genErrMsg(r *Request, sv *serverConn, what string) string { if sv == nil { return fmt.Sprintf("

HTTP Request %v

%s

", r, what) } - return fmt.Sprintf("

HTTP Request %v

%s

Using %s connection.

", - r, what, ctName[sv.connType]) + return fmt.Sprintf("

HTTP Request %v

%s

Using %s.

", + r, what, sv.Conn) } func (c *clientConn) handleBlockedRequest(r *Request, err error) error { @@ -422,11 +415,11 @@ func (c *clientConn) handleBlockedRequest(r *Request, err error) error { } func (c *clientConn) handleServerReadError(r *Request, sv *serverConn, err error, msg string) error { + if debug { + debug.Printf("cli(%s) server read error %s %v %v\n", c.RemoteAddr(), msg, err, r) + } var errMsg string if err == io.EOF { - if debug { - debug.Printf("cli(%s) %s read from server EOF\n", c.RemoteAddr(), msg) - } return RetryError{err} } if sv.maybeFake() && maybeBlocked(err) { @@ -437,8 +430,8 @@ func (c *clientConn) handleServerReadError(r *Request, sv *serverConn, err error sendErrorPage(c, "502 read error", err.Error(), errMsg) return errPageSent } - errl.Println(msg+" unhandled server read error:", err, reflect.TypeOf(err), r) - return errShouldClose + errl.Printf("cli(%s) unhandled server read error %s %v %v\n", c.RemoteAddr(), msg, err, r) + return err } func (c *clientConn) handleServerWriteError(r *Request, sv *serverConn, err error, msg string) error { @@ -451,6 +444,10 @@ func (c *clientConn) handleServerWriteError(r *Request, sv *serverConn, err erro } func dbgPrintRep(c *clientConn, r *Request, rp *Response) { + if rp.Trailer { + errl.Printf("cli(%s) response %s has Trailer header\n%s", + c.RemoteAddr(), rp, rp.Verbose()) + } if dbgRep { if verbose { dbgRep.Printf("cli(%s) response %s %s\n%s", @@ -465,9 +462,7 @@ func dbgPrintRep(c *clientConn, r *Request, rp *Response) { func (c *clientConn) readResponse(sv *serverConn, r *Request, rp *Response) (err error) { sv.initBuf() defer func() { - if rp != nil { - rp.releaseBuf() - } + rp.releaseBuf() }() /* @@ -484,8 +479,9 @@ func (c *clientConn) readResponse(sv *serverConn, r *Request, rp *Response) (err */ if err = parseResponse(sv, r, rp); err != nil { - return c.handleServerReadError(r, sv, err, "Parse response from server.") + return c.handleServerReadError(r, sv, err, "parse response") } + dbgPrintRep(c, r, rp) // After have received the first reponses from the server, we consider // ther server as real instead of fake one caused by wrong DNS reply. So // don't time out later. @@ -496,12 +492,11 @@ func (c *clientConn) readResponse(sv *serverConn, r *Request, rp *Response) (err if _, err = c.Write(rp.rawResponse()); err != nil { return err } - dbgPrintRep(c, r, rp) rp.releaseBuf() if rp.hasBody(r.Method) { - if err = sendBody(c, sv, nil, rp); err != nil { + if err = sendBody(c, sv.bufRd, int(rp.ContLen), rp.Chunking); err != nil { if debug { debug.Printf("cli(%s) send body %v\n", c.RemoteAddr(), err) } @@ -514,14 +509,12 @@ func (c *clientConn) readResponse(sv *serverConn, r *Request, rp *Response) (err // The client connection will be closed to indicate this error. // Proxy can't send error page here because response header has // been sent. 
- errl.Println("unexpected EOF reading body from server", r) + return fmt.Errorf("read response body unexpected EOF %v", rp) } else if isErrOpRead(err) { - return c.handleServerReadError(r, sv, err, "Read response body from server.") - } else if isErrOpWrite(err) { - return err + return c.handleServerReadError(r, sv, err, "read response body") } - errl.Println("sendBody unknown network op error", reflect.TypeOf(err), r) - return errShouldClose + // errl.Println("sendBody unknown network op error", reflect.TypeOf(err), r) + return err } } r.state = rsDone @@ -530,77 +523,60 @@ func (c *clientConn) readResponse(sv *serverConn, r *Request, rp *Response) (err debug.Printf("[Finished] %v request %s %s\n", c.RemoteAddr(), r.Method, r.URL) } */ - var remoteAddr string // avoid evaluating c.RemoteAddr() in the following debug call - if debug { - remoteAddr = c.RemoteAddr().String() - } if rp.ConnectionKeepAlive { if rp.KeepAlive == time.Duration(0) { sv.willCloseOn = time.Now().Add(defaultServerConnTimeout) } else { - debug.Printf("cli(%s) server %s keep-alive %v\n", - remoteAddr, sv.url.HostPort, rp.KeepAlive) + // debug.Printf("cli(%s) server %s keep-alive %v\n", c.RemoteAddr(), sv.hostPort, rp.KeepAlive) sv.willCloseOn = time.Now().Add(rp.KeepAlive) } - } else { - debug.Printf("cli(%s) server %s close connection\n", - remoteAddr, sv.url.HostPort) - c.removeServerConn(sv) } return } -func (c *clientConn) shouldCleanServerConn() bool { - return len(c.serverConn) > 0 && - time.Now().Sub(c.cleanedOn) > cleanServerInterval -} +func (c *clientConn) getServerConn(r *Request) (*serverConn, error) { + // For CONNECT method, always create new connection. + if r.isConnect { + return c.createServerConn(r) + } -// Remove all maybe closed server connection -func (c *clientConn) cleanServerConn() { - now := time.Now() - c.cleanedOn = now - for _, sv := range c.serverConn { - if now.After(sv.willCloseOn) { - c.removeServerConn(sv) + sv := connPool.Get(r.URL.HostPort) + if sv != nil { + if debug { + debug.Printf("cli(%s) connPool get %s\n", c.RemoteAddr(), r.URL.HostPort) } + return sv, nil } if debug { - debug.Printf("cli(%s) close idle connections, remains %d\n", - c.RemoteAddr(), len(c.serverConn)) - } -} - -func (c *clientConn) getServerConn(r *Request) (sv *serverConn, err error) { - sv, ok := c.serverConn[r.URL.HostPort] - if ok && sv.mayBeClosed() { - // debug.Printf("Connection to %s maybe closed\n", sv.url.HostPort) - c.removeServerConn(sv) - ok = false - } - if !ok { - sv, err = c.createServerConn(r) + debug.Printf("cli(%s) connPool no conn %s", c.RemoteAddr(), r.URL.HostPort) } - return + return c.createServerConn(r) } -func (c *clientConn) removeServerConn(sv *serverConn) { - sv.Close(c) - delete(c.serverConn, sv.url.HostPort) -} - -func connectDirect(url *URL, siteInfo *VisitCnt) (conn, error) { - to := dialTimeout - if siteInfo.OnceBlocked() && to >= defaultDialTimeout { - to = minDialTimeout +func connectDirect(url *URL, siteInfo *VisitCnt) (net.Conn, error) { + var c net.Conn + var err error + if siteInfo.AlwaysDirect() { + c, err = net.Dial("tcp", url.HostPort) + } else { + to := dialTimeout + if siteInfo.OnceBlocked() && to >= defaultDialTimeout { + // If once blocked, decrease timeout to switch to parent proxy faster. + to = minDialTimeout + } else if siteInfo.AsDirect() { + // If usually can be accessed directly, increase timeout to avoid + // problems when network condition is bad. 
+ to = maxTimeout + } + c, err = net.DialTimeout("tcp", url.HostPort, to) } - c, err := net.DialTimeout("tcp", url.HostPort, to) if err != nil { // Time out is very likely to be caused by GFW debug.Printf("error direct connect to: %s %v\n", url.HostPort, err) - return zeroConn, err + return nil, err } // debug.Println("directly connected to", url.HostPort) - return conn{ctDirectConn, c, nil}, nil + return directConn{c}, nil } func isErrTimeout(err error) bool { @@ -616,13 +592,13 @@ func maybeBlocked(err error) bool { // Connect to requested server according to whether it's visit count. // If direct connection fails, try parent proxies. -func (c *clientConn) connect(r *Request, siteInfo *VisitCnt) (srvconn conn, err error) { +func (c *clientConn) connect(r *Request, siteInfo *VisitCnt) (srvconn net.Conn, err error) { var errMsg string if config.AlwaysProxy { if srvconn, err = connectByParentProxy(r.URL); err == nil { return } - errMsg = genErrMsg(r, nil, "Parent proxy connection failed, always using parent proxy.") + errMsg = genErrMsg(r, nil, "Parent proxy connection failed, always use parent proxy.") goto fail } if siteInfo.AsBlocked() && hasParentProxy { @@ -667,15 +643,19 @@ func (c *clientConn) connect(r *Request, siteInfo *VisitCnt) (srvconn conn, err var socksErr error if srvconn, socksErr = connectByParentProxy(r.URL); socksErr == nil { c.handleBlockedRequest(r, err) - debug.Println("direct connection failed, use parent proxy for", r) + if debug { + debug.Printf("cli(%s) direct connection failed, use parent proxy for %v\n", + c.RemoteAddr(), r) + } return srvconn, nil } - errMsg = genErrMsg(r, nil, "Direct and parent proxy connection failed, maybe blocked site.") + errMsg = genErrMsg(r, nil, + "Direct and parent proxy connection failed, maybe blocked site.") } fail: sendErrorPage(c, "504 Connection failed", err.Error(), errMsg) - return zeroConn, errPageSent + return nil, errPageSent } func (c *clientConn) createServerConn(r *Request) (*serverConn, error) { @@ -684,35 +664,29 @@ func (c *clientConn) createServerConn(r *Request) (*serverConn, error) { if err != nil { return nil, err } - sv := newServerConn(srvconn, r.URL, siteInfo) - if r.isConnect { - // Don't put connection for CONNECT method for reuse - return sv, nil - } - c.serverConn[sv.url.HostPort] = sv + sv := newServerConn(srvconn, r.URL.HostPort, siteInfo) if debug { debug.Printf("cli(%s) connected to %s %d concurrent connections\n", - c.RemoteAddr(), sv.url.HostPort, incSrvConnCnt(sv.url.HostPort)) + c.RemoteAddr(), sv.hostPort, incSrvConnCnt(sv.hostPort)) } - // client will connect to differnet servers in a single proxy connection - // debug.Printf("serverConn to for client %v %v\n", c.RemoteAddr(), c.serverConn) return sv, nil } // Should call initBuf before reading http response from server. This allows // us not init buf for connect method which does not need to parse http // respnose. 
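
// Illustrative sketch, not code from this patch: the fallback order used by
// connect() above, reduced to its core. parentDial stands in for
// connectByParentProxy and the timeout value is an assumption.
package sketch

import (
    "net"
    "time"
)

func dialWithFallback(hostPort string, parentDial func(string) (net.Conn, error)) (net.Conn, error) {
    // Try a direct connection first; a timeout here is often a sign of blocking.
    if c, err := net.DialTimeout("tcp", hostPort, 5*time.Second); err == nil {
        return c, nil
    }
    // Direct connection failed, fall back to the parent proxy.
    return parentDial(hostPort)
}
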
-func newServerConn(c conn, url *URL, siteInfo *VisitCnt) *serverConn { +func newServerConn(c net.Conn, hostPort string, siteInfo *VisitCnt) *serverConn { sv := &serverConn{ - conn: c, - url: url, + Conn: c, + hostPort: hostPort, siteInfo: siteInfo, } return sv } func (sv *serverConn) isDirect() bool { - return sv.connType == ctDirectConn + _, ok := sv.Conn.(directConn) + return ok } func (sv *serverConn) updateVisit() { @@ -734,16 +708,20 @@ func (sv *serverConn) initBuf() { } } -func (sv *serverConn) Close(c *clientConn) error { - sv.bufRd = nil - if sv.buf != nil { +func (sv *serverConn) releaseBuf() { + if sv.bufRd != nil { // debug.Println("release server buffer") httpBuf.Put(sv.buf) sv.buf = nil + sv.bufRd = nil } +} + +func (sv *serverConn) Close() error { + sv.releaseBuf() if debug { - debug.Printf("cli(%s) close connection to %s remains %d concurrent connections\n", - c.RemoteAddr(), sv.url.HostPort, decSrvConnCnt(sv.url.HostPort)) + debug.Printf("close connection to %s remains %d concurrent connections\n", + sv.hostPort, decSrvConnCnt(sv.hostPort)) } return sv.Conn.Close() } @@ -760,16 +738,17 @@ func setConnReadTimeout(cn net.Conn, d time.Duration, msg string) { func unsetConnReadTimeout(cn net.Conn, msg string) { if err := cn.SetReadDeadline(zeroTime); err != nil { - errl.Println("Unset readtimeout:", msg, err) + // It's possible that conn has been closed, so use debug log. + debug.Println("Unset readtimeout:", msg, err) } } -// setReadTimeout will only set timeout if the server connection maybe fake. -// In case it's not fake, this will unset timeout. func (sv *serverConn) setReadTimeout(msg string) { to := readTimeout if sv.siteInfo.OnceBlocked() && to > defaultReadTimeout { to = minReadTimeout + } else if sv.siteInfo.AsDirect() { + to = maxTimeout } setConnReadTimeout(sv.Conn, to, msg) } @@ -860,6 +839,7 @@ func newServerWriter(r *Request, sv *serverConn) *serverWriter { } // Write to server, store written data in request buffer if necessary. +// We have to save request body in order to retry request. // FIXME: too tighly coupled with Request. func (sw *serverWriter) Write(p []byte) (int, error) { if sw.rq.raw == nil { @@ -868,27 +848,28 @@ func (sw *serverWriter) Write(p []byte) (int, error) { // Avoid using too much memory to hold request body. If a request is // not buffered completely, COW can't retry and can release memory // immediately. - debug.Println("request body too large, not buffering any more") + debug.Println(sw.rq, "request body too large, not buffering any more") sw.rq.releaseBuf() sw.rq.partial = true } else if sw.rq.responseNotSent() { sw.rq.raw.Write(p) - } else { // has sent response + } else { // has sent response, happens when saving data for CONNECT method sw.rq.releaseBuf() } return sw.sv.Write(p) } -func copyClient2Server(c *clientConn, sv *serverConn, r *Request, srvStopped notification, done chan byte) (err error) { +func copyClient2Server(c *clientConn, sv *serverConn, r *Request, srvStopped notification, done chan struct{}) (err error) { // sv.maybeFake may change during execution in this function. // So need a variable to record the whether timeout is set. deadlineIsSet := false defer func() { if deadlineIsSet { - // maybe need to retry, should unset timeout here because + // May need to retry, unset timeout here to avoid read client + // timeout on retry. Note c.Conn maybe closed when calling this. 
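
// Illustrative sketch, not code from this patch: with the connType enum gone,
// the kind of server connection is recovered by a type assertion on a wrapper
// type, as isDirect does above. The definition below is assumed.
package sketch

import "net"

// directConn wraps a connection that was dialed directly (no parent proxy).
type directConn struct{ net.Conn }

func isDirect(c net.Conn) bool {
    _, ok := c.(directConn)
    return ok
}
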
unsetConnReadTimeout(c.Conn, "cli->srv after err") } - done <- 1 + close(done) }() var n int @@ -974,11 +955,14 @@ var connEstablished = []byte("HTTP/1.1 200 Tunnel established\r\n\r\n") func (sv *serverConn) doConnect(r *Request, c *clientConn) (err error) { r.state = rsCreated - if sv.connType == ctHttpProxyConn { - // debug.Printf("%s Sending CONNECT request to http proxy server\n", c.RemoteAddr()) - if err = sv.sendHTTPProxyRequest(r, c); err != nil { + _, isHttpConn := sv.Conn.(httpConn) + if isHttpConn { + if debug { + debug.Printf("cli(%s) send CONNECT request to http parent\n", c.RemoteAddr()) + } + if err = sv.sendHTTPProxyRequestHeader(r, c); err != nil { if debug { - debug.Printf("cli(%s) error sending CONNECT request to http proxy server: %v\n", + debug.Printf("cli(%s) error send CONNECT request to http proxy server: %v\n", c.RemoteAddr(), err) } return err @@ -987,7 +971,7 @@ func (sv *serverConn) doConnect(r *Request, c *clientConn) (err error) { // debug.Printf("send connection confirmation to %s->%s\n", c.RemoteAddr(), r.URL.HostPort) if _, err = c.Write(connEstablished); err != nil { if debug { - debug.Printf("cli(%s) error sending 200 Connecion established: %v\n", + debug.Printf("cli(%s) error send 200 Connecion established: %v\n", c.RemoteAddr(), err) } return err @@ -995,12 +979,14 @@ func (sv *serverConn) doConnect(r *Request, c *clientConn) (err error) { } var cli2srvErr error - done := make(chan byte, 1) + done := make(chan struct{}) srvStopped := newNotification() go func() { // debug.Printf("doConnect: cli(%s)->srv(%s)\n", c.RemoteAddr(), r.URL.HostPort) cli2srvErr = copyClient2Server(c, sv, r, srvStopped, done) - sv.Close(c) // close sv to force read from server in copyServer2Client return + // Close sv to force read from server in copyServer2Client return. + // Note: there's no other code closing the server connection for CONNECT. + sv.Close() }() // debug.Printf("doConnect: srv(%s)->cli(%s)\n", r.URL.HostPort, c.RemoteAddr()) @@ -1019,25 +1005,26 @@ func (sv *serverConn) doConnect(r *Request, c *clientConn) (err error) { return } -func (sv *serverConn) sendHTTPProxyRequest(r *Request, c *clientConn) (err error) { +func (sv *serverConn) sendHTTPProxyRequestHeader(r *Request, c *clientConn) (err error) { if _, err = sv.Write(r.proxyRequestLine()); err != nil { return c.handleServerWriteError(r, sv, err, - "sending proxy request line to http parent") + "send proxy request line to http parent") } // Add authorization header for parent http proxy - hp, ok := sv.creator.(*httpParent) + hc, ok := sv.Conn.(httpConn) if !ok { panic("must be http parent connection") } - if hp.authHeader != nil { - if _, err = sv.Write(hp.authHeader); err != nil { + if hc.parent.authHeader != nil { + if _, err = sv.Write(hc.parent.authHeader); err != nil { return c.handleServerWriteError(r, sv, err, - "sending proxy authorization header to http parent") + "send proxy authorization header to http parent") } } + // When retry, body is in raw buffer. 
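
// Illustrative sketch, not code from this patch: the CONNECT tunnel pattern
// used by doConnect above, with the done channel closed (rather than sent to)
// so any number of waiters are released. Error handling is omitted.
package sketch

import (
    "io"
    "net"
)

func tunnel(client, server net.Conn) {
    done := make(chan struct{})
    go func() {
        defer close(done)
        io.Copy(server, client) // client -> server
        // Closing the server connection forces the server -> client copy
        // below to return, mirroring sv.Close() in the goroutine above.
        server.Close()
    }()
    io.Copy(client, server) // server -> client
    <-done
}
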
if _, err = sv.Write(r.rawHeaderBody()); err != nil { return c.handleServerWriteError(r, sv, err, - "sending proxy request header to http parent") + "send proxy request header to http parent") } /* if bool(dbgRq) && verbose { @@ -1047,10 +1034,11 @@ func (sv *serverConn) sendHTTPProxyRequest(r *Request, c *clientConn) (err error return } -func (sv *serverConn) sendRequest(r *Request, c *clientConn) (err error) { +func (sv *serverConn) sendRequestHeader(r *Request, c *clientConn) (err error) { // Send request to the server - if sv.connType == ctHttpProxyConn { - return sv.sendHTTPProxyRequest(r, c) + _, isHttpConn := sv.Conn.(httpConn) + if isHttpConn { + return sv.sendHTTPProxyRequestHeader(r, c) } /* if bool(debug) && verbose { @@ -1058,7 +1046,28 @@ func (sv *serverConn) sendRequest(r *Request, c *clientConn) (err error) { } */ if _, err = sv.Write(r.rawRequest()); err != nil { - err = c.handleServerWriteError(r, sv, err, "sending request to server") + err = c.handleServerWriteError(r, sv, err, "send request to server") + } + return +} + +func (sv *serverConn) sendRequestBody(r *Request, c *clientConn) (err error) { + // Send request body. If this is retry, r.raw contains request body and is + // sent while sending raw request. + if !r.hasBody() || r.isRetry() { + return + } + + err = sendBody(newServerWriter(r, sv), c.bufRd, int(r.ContLen), r.Chunking) + if err != nil { + errl.Printf("cli(%s) send request body error %v %s\n", c.RemoteAddr(), err, r) + if isErrOpWrite(err) { + err = c.handleServerWriteError(r, sv, err, "send request body") + } + return + } + if debug { + debug.Printf("cli(%s) request body sent %s\n", c.RemoteAddr(), r) } return } @@ -1066,39 +1075,21 @@ func (sv *serverConn) sendRequest(r *Request, c *clientConn) (err error) { // Do HTTP request other that CONNECT func (sv *serverConn) doRequest(c *clientConn, r *Request, rp *Response) (err error) { r.state = rsCreated - if err = sv.sendRequest(r, c); err != nil { + if err = sv.sendRequestHeader(r, c); err != nil { return } - - // Send request body. If this is retry, r.raw contains request body and is - // sent while sending request. - if !r.isRetry() && (r.Chunking || r.ContLen > 0) { - // Message body in request is signaled by the inclusion of a Content- - // Length or Transfer-Encoding header. Refer to http://stackoverflow.com/a/299696/306935 - if err = sendBody(c, sv, r, nil); err != nil { - if err == io.EOF && isErrOpRead(err) { - errl.Println("EOF reading request body from client", r) - } else if isErrOpWrite(err) { - err = c.handleServerWriteError(r, sv, err, "sending request body") - } else { - errl.Println("reading request body:", err) - } - return - } - if debug { - debug.Printf("cli(%s) %s request body sent\n", c.RemoteAddr(), r) - } + if err = sv.sendRequestBody(r, c); err != nil { + return } r.state = rsSent - err = c.readResponse(sv, r, rp) - if err == nil { + if err = c.readResponse(sv, r, rp); err == nil { sv.updateVisit() } return err } // Send response body if header specifies content length -func sendBodyWithContLen(r *bufio.Reader, w io.Writer, contLen int) (err error) { +func sendBodyWithContLen(w io.Writer, r *bufio.Reader, contLen int) (err error) { // debug.Println("Sending body with content length", contLen) if contLen == 0 { return @@ -1109,46 +1100,85 @@ func sendBodyWithContLen(r *bufio.Reader, w io.Writer, contLen int) (err error) return } +// Use this function until we find Trailer headers actually in use. 
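
// Illustrative sketch, not code from this patch: roughly what
// sendHTTPProxyRequestHeader writes to an HTTP parent for CONNECT, namely the
// request line, an optional Proxy-Authorization header, then a blank line.
// Basic auth and the helper name are assumptions.
package sketch

import (
    "encoding/base64"
    "fmt"
)

func buildConnectRequest(hostPort, user, passwd string) []byte {
    req := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", hostPort, hostPort)
    if user != "" {
        cred := base64.StdEncoding.EncodeToString([]byte(user + ":" + passwd))
        req += "Proxy-Authorization: Basic " + cred + "\r\n"
    }
    return []byte(req + "\r\n")
}
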
+func skipTrailer(r *bufio.Reader) error {
+    // It's possible to get trailer headers, but the body will always end with
+    // a line with just CRLF.
+    for {
+        s, err := r.ReadSlice('\n')
+        if err != nil {
+            errl.Println("skip trailer:", err)
+            return err
+        }
+        if len(s) == 2 && s[0] == '\r' && s[1] == '\n' {
+            return nil
+        }
+        errl.Printf("skip trailer: %#v", string(s))
+        if len(s) == 1 || len(s) == 2 {
+            return fmt.Errorf("malformed chunk body end: %#v", string(s))
+        }
+    }
+}
+
+func skipCRLF(r *bufio.Reader) (err error) {
+    var buf [2]byte
+    if _, err = io.ReadFull(r, buf[:]); err != nil {
+        errl.Println("skip chunk body end:", err)
+        return
+    }
+    if buf[0] != '\r' || buf[1] != '\n' {
+        return fmt.Errorf("malformed chunk body end: %#v", string(buf[:]))
+    }
+    return
+}
+
 // Send response body if header specifies chunked encoding. rdSize specifies
 // the size of each read on Reader, it should be set to be the buffer size of
 // the Reader, this parameter is added for testing.
-func sendBodyChunked(r *bufio.Reader, w io.Writer, rdSize int) (err error) {
+func sendBodyChunked(w io.Writer, r *bufio.Reader, rdSize int) (err error) {
     // debug.Println("Sending chunked body")
     for {
         var s []byte
         // Read chunk size line, ignore chunk extension if any.
         if s, err = r.PeekSlice('\n'); err != nil {
-            errl.Println("peeking chunk size:", err)
+            errl.Println("peek chunk size:", err)
             return
         }
-        // debug.Printf("Chunk size line %s\n", s)
         smid := bytes.IndexByte(s, ';')
         if smid == -1 {
             smid = len(s)
+        } else {
+            // use error log to find usage of chunk extension
+            errl.Printf("got chunk extension: %s\n", s)
         }
         var size int64
         if size, err = ParseIntFromBytes(TrimSpace(s[:smid]), 16); err != nil {
             errl.Println("chunk size invalid:", err)
             return
         }
-        // end of chunked data. As we remove trailer header in request sending
-        // to server, there should be no trailer in response.
-        // TODO: Is it possible for client request body to have trailers in it?
+        if debug {
+            // To debug getting malformed response status line with "0\r\n".
+            if c, ok := w.(*clientConn); ok {
+                debug.Printf("cli(%s) chunk size %d %#v\n", c.RemoteAddr(), size, string(s))
+            }
+        }
         if size == 0 {
             r.Skip(len(s))
-            skipCRLF(r)
+            if err = skipCRLF(r); err != nil {
+                return
+            }
             if _, err = w.Write([]byte(chunkEnd)); err != nil {
-                debug.Println("sending chunk ending:", err)
+                debug.Println("send chunk ending:", err)
             }
             return
         }
-        // The spec section 19.3 only suggest toleranting single LF for
+        // RFC 2616 19.3 only suggest tolerating single LF for
         // headers, not for chunked encoding. So assume the server will send
         // CRLF. If not, the following parse int may find errors.
         total := len(s) + int(size) + 2 // total data size for this chunk, including ending CRLF
         // PeekSlice will not advance reader, so we can just copy total sized data.
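
// Illustrative sketch, not code from this patch: parsing a chunk-size line as
// sendBodyChunked does above, taking the hex size and ignoring any
// ";name=value" chunk extension. Uses the standard library instead of
// ParseIntFromBytes.
package main

import (
    "bytes"
    "fmt"
    "strconv"
)

func parseChunkSize(line []byte) (int64, error) {
    if i := bytes.IndexByte(line, ';'); i != -1 {
        line = line[:i] // drop the chunk extension
    }
    return strconv.ParseInt(string(bytes.TrimSpace(line)), 16, 64)
}

func main() {
    size, err := parseChunkSize([]byte("1a;ext=val\r\n"))
    fmt.Println(size, err) // 26 <nil>
}
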
if err = copyN(w, r, total, rdSize); err != nil { - debug.Println("copying chunked data:", err) + debug.Println("copy chunked data:", err) return } } @@ -1157,7 +1187,7 @@ func sendBodyChunked(r *bufio.Reader, w io.Writer, rdSize int) (err error) { const CRLF = "\r\n" const chunkEnd = "0\r\n\r\n" -func sendBodySplitIntoChunk(r *bufio.Reader, w io.Writer) (err error) { +func sendBodySplitIntoChunk(w io.Writer, r *bufio.Reader) (err error) { // debug.Printf("sendBodySplitIntoChunk called\n") var b []byte for { @@ -1169,7 +1199,7 @@ func sendBodySplitIntoChunk(r *bufio.Reader, w io.Writer) (err error) { // debug.Println("end chunked encoding") _, err = w.Write([]byte(chunkEnd)) if err != nil { - debug.Println("Write chunk end 0") + debug.Println("write chunk end 0", err) } return } @@ -1179,56 +1209,34 @@ func sendBodySplitIntoChunk(r *bufio.Reader, w io.Writer) (err error) { chunkSize := []byte(fmt.Sprintf("%x\r\n", len(b))) if _, err = w.Write(chunkSize); err != nil { - debug.Printf("writing chunk size %v\n", err) + debug.Printf("write chunk size %v\n", err) return } if _, err = w.Write(b); err != nil { - debug.Println("writing chunk data:", err) + debug.Println("write chunk data:", err) return } if _, err = w.Write([]byte(CRLF)); err != nil { - debug.Println("writing chunk ending CRLF:", err) + debug.Println("write chunk ending CRLF:", err) return } } } -// Send message body. If req is not nil, read from client, send to server. If -// rp is not nil, the direction is the oppisite. -func sendBody(c *clientConn, sv *serverConn, req *Request, rp *Response) (err error) { - var contLen int - var chunk bool - var bufRd *bufio.Reader - var w io.Writer - - if rp != nil { // read responses from server, write to client - w = c - bufRd = sv.bufRd - contLen = int(rp.ContLen) - chunk = rp.Chunking - } else if req != nil { // read request body from client, send to server - // The server connection may have been closed, need to retry request in that case. - // So always need to save request body. - w = newServerWriter(req, sv) - bufRd = c.bufRd - contLen = int(req.ContLen) - chunk = req.Chunking - } else { - panic("sendBody must have either request or response not nil") - } - +// Send message body. +func sendBody(w io.Writer, bufRd *bufio.Reader, contLen int, chunk bool) (err error) { // chunked encoding has precedence over content length - // COW does not sanitize response header, but should correctly handle it + // COW does not sanitize response header, but can correctly handle it if chunk { - err = sendBodyChunked(bufRd, w, httpBufSize) + err = sendBodyChunked(w, bufRd, httpBufSize) } else if contLen >= 0 { - err = sendBodyWithContLen(bufRd, w, contLen) + // It's possible to have content length 0 if server response has no + // body. + err = sendBodyWithContLen(w, bufRd, int(contLen)) } else { - if req != nil { - errl.Println("client request with body but no length or chunked encoding specified.") - return errBadRequest - } - err = sendBodySplitIntoChunk(bufRd, w) + // Must be reading server response here, because sendBody is called in + // reading response iff chunked or content length > 0. 
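
// Illustrative sketch, not code from this patch: the "split into chunk" path
// above re-frames a body that has neither Content-Length nor chunked encoding
// (read until EOF) as chunked data for the client, ending with "0\r\n\r\n".
// Intermediate write errors are ignored for brevity.
package sketch

import (
    "fmt"
    "io"
)

func repackAsChunks(w io.Writer, r io.Reader) error {
    buf := make([]byte, 8192)
    for {
        n, err := r.Read(buf)
        if n > 0 {
            fmt.Fprintf(w, "%x\r\n", n) // chunk size in hex
            w.Write(buf[:n])
            io.WriteString(w, "\r\n") // chunk ending CRLF
        }
        if err == io.EOF {
            _, werr := io.WriteString(w, "0\r\n\r\n") // last chunk
            return werr
        }
        if err != nil {
            return err
        }
    }
}
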
+ err = sendBodySplitIntoChunk(w, bufRd) } return } diff --git a/proxy_test.go b/proxy_test.go index 5e63f770..abba748a 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -22,14 +22,19 @@ func TestSendBodyChunked(t *testing.T) { */ } + // supress error log when finding chunk extension + errl = false + defer func() { + errl = true + }() // use different reader buffer size to test for both all buffered and partially buffered chunk sizeArr := []int{32, 64, 128} for _, size := range sizeArr { for _, td := range testData { r := bufio.NewReaderSize(strings.NewReader(td.raw), size) - var w bytes.Buffer + w := new(bytes.Buffer) - if err := sendBodyChunked(r, &w, size); err != nil { + if err := sendBodyChunked(w, r, size); err != nil { t.Fatalf("sent data %q err: %v\n", w.Bytes(), err) } if td.want == "" { diff --git a/script/build.sh b/script/build.sh index 9f95226b..7e1986d4 100755 --- a/script/build.sh +++ b/script/build.sh @@ -8,19 +8,27 @@ echo "creating cow binary version $version" mkdir -p bin build() { local name - local GOOS - local GOARCH + local goos + local goarch + local goarm + local cgo + + goos="GOOS=$1" + goarch="GOARCH=$2" + if [[ $3 == "linux-armv5" ]]; then + goarm="GOARM=5" + fi if [[ $1 == "darwin" ]]; then # Enable CGO for OS X so change network location will not cause problem. - export CGO_ENABLED=1 + cgo="CGO_ENABLED=1" else - export CGO_ENABLED=0 + cgo="CGO_ENABLED=0" fi name=cow-$3-$version echo "building $name" - GOOS=$1 GOARCH=$2 go build || exit 1 + eval $cgo $goos $goarch $goarm go build || exit 1 if [[ $1 == "windows" ]]; then mv cow.exe script pushd script @@ -29,7 +37,7 @@ build() { rm -f cow.exe sample-rc.txt mv $name.zip ../bin/ popd - else + else mv cow bin/$name gzip -f bin/$name fi @@ -39,6 +47,6 @@ build darwin amd64 mac64 build linux amd64 linux64 build linux 386 linux32 build linux arm linux-armv6 +build linux arm linux-armv5 build windows amd64 win64 build windows 386 win32 - diff --git a/sitestat.go b/sitestat.go index 4b63531b..c44c704d 100644 --- a/sitestat.go +++ b/sitestat.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "github.com/cyfdecyf/bufio" - "io" "io/ioutil" "math/rand" "os" @@ -23,7 +22,7 @@ func init() { // judging whether a site is blocked or not is more reliable. const ( - directDelta = 20 + directDelta = 15 blockedDelta = 10 maxCnt = 100 // no protect to update visit cnt, smaller value is unlikely to overflow userCnt = -1 // this represents user specified host or domain @@ -442,20 +441,17 @@ func loadSiteList(fpath string) (lst []string, err error) { } defer f.Close() - fr := bufio.NewReader(f) + scanner := bufio.NewScanner(f) lst = make([]string, 0) - var site string - for { - site, err = ReadLine(fr) - if err == io.EOF { - return lst, nil - } else if err != nil { - errl.Printf("Error reading domain list %s: %v\n", fpath, err) - return - } + for scanner.Scan() { + site := strings.TrimSpace(scanner.Text()) if site == "" { continue } - lst = append(lst, strings.TrimSpace(site)) + lst = append(lst, site) + } + if scanner.Err() != nil { + errl.Printf("Error reading domain list %s: %v\n", fpath, scanner.Err()) } + return lst, scanner.Err() } diff --git a/util.go b/util.go index be8bcea3..bc4b5ea4 100644 --- a/util.go +++ b/util.go @@ -37,37 +37,6 @@ func (n notification) hasNotified() bool { } } -// ReadLine read till '\n' is found or encounter error. The returned line does -// not include ending '\r' and '\n'. If returns err != nil if and only if -// len(line) == 0. 
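
// Illustrative sketch, not code from this patch: the loadSiteList rewrite
// above in stand-alone form. bufio.Scanner replaces the hand-rolled ReadLine
// helper, skipping blank lines and reporting scanner errors at the end.
package sketch

import (
    "bufio"
    "os"
    "strings"
)

func readSiteList(fpath string) ([]string, error) {
    f, err := os.Open(fpath)
    if err != nil {
        return nil, err
    }
    defer f.Close()

    var lst []string
    scanner := bufio.NewScanner(f)
    for scanner.Scan() {
        site := strings.TrimSpace(scanner.Text())
        if site == "" {
            continue
        }
        lst = append(lst, site)
    }
    return lst, scanner.Err()
}
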
-func ReadLine(r *bufio.Reader) (string, error) { - l, err := ReadLineSlice(r) - return string(l), err -} - -// ReadLineBytes read till '\n' is found or encounter error. The returned line -// does not include ending '\r\n' or '\n'. Returns err != nil if and only if -// len(line) == 0. Note the returned byte should not be used for append and -// maybe overwritten by next I/O operation. Copied code of readLineSlice from -// $GOROOT/src/pkg/net/textproto/reader.go -func ReadLineSlice(r *bufio.Reader) (line []byte, err error) { - for { - l, more, err := r.ReadLine() - if err != nil { - return nil, err - } - // Avoid the copy if the first call produced a full line. - if line == nil && !more { - return l, nil - } - line = append(line, l...) - if !more { - break - } - } - return line, nil -} - func ASCIIToUpperInplace(b []byte) { for i := 0; i < len(b); i++ { if 97 <= b[i] && b[i] <= 122 { @@ -124,9 +93,6 @@ func IsSpace(b byte) bool { } func TrimSpace(s []byte) []byte { - if len(s) == 0 { - return s - } st := 0 end := len(s) - 1 for ; st < len(s) && IsSpace(s[st]); st++ { @@ -139,6 +105,13 @@ func TrimSpace(s []byte) []byte { return s[st : end+1] } +func TrimTrailingSpace(s []byte) []byte { + end := len(s) - 1 + for ; end >= 0 && IsSpace(s[end]); end-- { + } + return s[:end+1] +} + // FieldsN is simliar with bytes.Fields, but only consider space and '\t' as // space, and will include all content in the final slice with ending white // space characters trimmed. bytes.Split can't split on both space and '\t', @@ -195,11 +168,11 @@ func ParseIntFromBytes(b []byte, base int) (n int64, err error) { // Refer to: http://code.google.com/p/go/issues/detail?id=2632 // That's why I created this function. if base != 10 && base != 16 { - err = errors.New(fmt.Sprintf("Invalid base: %d\n", base)) + err = errors.New(fmt.Sprintf("invalid base: %d", base)) return } if len(b) == 0 { - err = errors.New("Parse int from empty string") + err = errors.New("parse int from empty bytes") return } @@ -215,12 +188,12 @@ func ParseIntFromBytes(b []byte, base int) (n int64, err error) { v := digitTbl[d] if v == -1 { n = 0 - err = errors.New(fmt.Sprintf("Invalid number: %s", b)) + err = errors.New(fmt.Sprintf("invalid number: %s", b)) return } if int(v) >= base { n = 0 - err = errors.New(fmt.Sprintf("Invalid base %d number: %s", base, b)) + err = errors.New(fmt.Sprintf("invalid base %d number: %s", base, b)) return } n *= int64(base) @@ -260,22 +233,6 @@ func isDirExists(path string) (bool, error) { return false, err } -// Get host IP address -func hostIP() (addrs []string, err error) { - name, err := os.Hostname() - if err != nil { - fmt.Printf("Error get host name: %v\n", err) - return - } - - addrs, err = net.LookupHost(name) - if err != nil { - fmt.Printf("Error getting host IP address: %v\n", err) - return - } - return -} - func getUserHomeDir() string { home := os.Getenv("HOME") if home == "" { @@ -293,7 +250,7 @@ func expandTilde(pth string) string { } // copyN copys N bytes from src to dst, reading at most rdSize for each read. -// rdSize should be smaller than the buffer size of Reader. +// rdSize should <= buffer size of the buffered reader. // Returns any encountered error. 
func copyN(dst io.Writer, src *bufio.Reader, n, rdSize int) (err error) { // Most of the copy is copied from io.Copy diff --git a/util_test.go b/util_test.go index eb71ed51..3946f4d3 100644 --- a/util_test.go +++ b/util_test.go @@ -4,43 +4,10 @@ import ( "bytes" "errors" "github.com/cyfdecyf/bufio" - "io" "strings" "testing" ) -func TestReadLine(t *testing.T) { - testData := []struct { - text string - lines []string - }{ - {"one\ntwo", []string{"one", "two"}}, - {"three\r\nfour\n", []string{"three", "four"}}, - {"five\r\nsix\r\n", []string{"five", "six"}}, - {"seven", []string{"seven"}}, - {"eight\n", []string{"eight"}}, - {"\r\n", []string{""}}, - {"\n", []string{""}}, - } - for _, td := range testData { - raw := strings.NewReader(td.text) - rd := bufio.NewReader(raw) - for i, line := range td.lines { - l, err := ReadLine(rd) - if err != nil { - t.Fatalf("%d read error %v got: %s\ntext: %s\n", i+1, err, l, td.text) - } - if line != l { - t.Fatalf("%d read got: %s (%d) should be: %s (%d)\n", i+1, l, len(l), line, len(line)) - } - } - _, err := ReadLine(rd) - if err != io.EOF { - t.Error("ReadLine past end should return EOF") - } - } -} - func TestASCIIToUpper(t *testing.T) { testData := []struct { raw []byte @@ -132,6 +99,26 @@ func TestTrimSpace(t *testing.T) { } } +func TestTrimTrailingSpace(t *testing.T) { + testData := []struct { + old string + trimed string + }{ + {"hello", "hello"}, + {" hello", " hello"}, + {" hello\r\n ", " hello"}, + {" hello \t ", " hello"}, + {"", ""}, + {"\r\n", ""}, + } + for _, td := range testData { + trimed := string(TrimTrailingSpace([]byte(td.old))) + if trimed != td.trimed { + t.Errorf("%s trimmed to %s, should be %s\n", td.old, trimed, td.trimed) + } + } +} + func TestFieldsN(t *testing.T) { testData := []struct { raw string