Skip to content

Commit

Permalink
Improve listener and dial failure handling (#18)
Browse files Browse the repository at this point in the history
Add a Listen/Serve pattern, and wrap Start around that, to allow initializing the listener prior to starting the service
flag to immediately close connections that fail to dial.
  • Loading branch information
hownowstephen committed Sep 2, 2023
1 parent 3d56ada commit 93cd1e6
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 13 deletions.
5 changes: 3 additions & 2 deletions private_key_file.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package sshtunnel

import (
"os"

"golang.org/x/crypto/ssh"
"io/ioutil"
)

func PrivateKeyFile(file string) ssh.AuthMethod {
buffer, err := ioutil.ReadFile(file)
buffer, err := os.ReadFile(file)
if err != nil {
return nil
}
Expand Down
41 changes: 30 additions & 11 deletions ssh_tunnel.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sshtunnel

import (
"errors"
"io"
"net"

Expand Down Expand Up @@ -38,11 +39,23 @@ func newConnectionWaiter(listener net.Listener, c chan net.Conn) {
c <- conn
}

func (tunnel *SSHTunnel) Start() error {
listener, err := net.Listen("tcp", tunnel.Local.String())
func (t *SSHTunnel) Listen() (net.Listener, error) {
return net.Listen("tcp", t.Local.String())
}

func (t *SSHTunnel) Start() error {
listener, err := t.Listen()
if err != nil {
t.logf("listen error: %s", err)
return err
}
defer listener.Close()

return t.Serve(listener)
}

func (tunnel *SSHTunnel) Serve(listener net.Listener) error {

tunnel.isOpen = true
tunnel.Local.Port = listener.Addr().(*net.TCPAddr).Port

Expand All @@ -61,15 +74,15 @@ func (tunnel *SSHTunnel) Start() error {

c := make(chan net.Conn)
go newConnectionWaiter(listener, c)
tunnel.logf("listening for new connections...")
tunnel.logf("listening for new connections on %s:%d...", tunnel.Local.Host, tunnel.Local.Port)

select {
case <-tunnel.close:
tunnel.logf("close signal received, closing...")
tunnel.isOpen = false
case conn := <-c:
tunnel.Conns = append(tunnel.Conns, conn)
tunnel.logf("accepted connection")
tunnel.logf("accepted connection from %s", conn.RemoteAddr().String())
go tunnel.forward(conn)
}
}
Expand All @@ -79,6 +92,10 @@ func (tunnel *SSHTunnel) Start() error {
tunnel.logf("closing the netConn (%d of %d)", i+1, total)
err := conn.Close()
if err != nil {
if errors.Is(err, net.ErrClosed) {
// no need to report on closed connections
continue
}
tunnel.logf(err.Error())
}
}
Expand All @@ -90,10 +107,7 @@ func (tunnel *SSHTunnel) Start() error {
tunnel.logf(err.Error())
}
}
err = listener.Close()
if err != nil {
return err
}

tunnel.logf("tunnel closed")
return nil
}
Expand All @@ -112,8 +126,16 @@ func (tunnel *SSHTunnel) forward(localConn net.Conn) {

if attemptsLeft <= 0 {
tunnel.logf("server dial error: %v: exceeded %d attempts", err, tunnel.MaxConnectionAttempts)

if err := localConn.Close(); err != nil {
tunnel.logf("failed to close local connection: %v", err)
return
}

tunnel.logf("dial failed, closing local connection: %v", err)
return
}
tunnel.logf("server dial error: %v: attempt %d/%d", err, tunnel.MaxConnectionAttempts-attemptsLeft, tunnel.MaxConnectionAttempts)
} else {
break
}
Expand All @@ -137,13 +159,10 @@ func (tunnel *SSHTunnel) forward(localConn net.Conn) {
}
go copyConn(localConn, remoteConn)
go copyConn(remoteConn, localConn)

return
}

func (tunnel *SSHTunnel) Close() {
tunnel.close <- struct{}{}
return
}

// NewSSHTunnel creates a new single-use tunnel. Supplying "0" for localport will use a random port.
Expand Down

0 comments on commit 93cd1e6

Please sign in to comment.