Skip to content

Commit

Permalink
Add connection retry with MaxConnectionAttempts (#16)
Browse files Browse the repository at this point in the history
The application using this library hangs (never exits) if the SSH tunnel is being
used by a lots of goroutines and an error occurs in the forward() method when a connection
is being made. Connection attempts seem to intermittently fail, and this somehow leads to the code
never exiting when complete.

I found that the connection attempt would succeed after 1 or 2 retries, then the application would later
exit like normal. This PR adds a optional retry mechanism. It must be enabled with MaxConnectionAttempts.

Fixed #15
  • Loading branch information
davidalexisnyt committed Feb 14, 2023
1 parent 6539d4e commit 27700fc
Showing 1 changed file with 38 additions and 14 deletions.
52 changes: 38 additions & 14 deletions ssh_tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@ type logger interface {
}

type SSHTunnel struct {
Local *Endpoint
Server *Endpoint
Remote *Endpoint
Config *ssh.ClientConfig
Log logger
Conns []net.Conn
SvrConns []*ssh.Client
isOpen bool
close chan interface{}
Local *Endpoint
Server *Endpoint
Remote *Endpoint
Config *ssh.ClientConfig
Log logger
Conns []net.Conn
SvrConns []*ssh.Client
MaxConnectionAttempts int
isOpen bool
close chan interface{}
}

func (tunnel *SSHTunnel) logf(fmt string, args ...interface{}) {
Expand All @@ -45,6 +46,14 @@ func (tunnel *SSHTunnel) Start() error {
tunnel.isOpen = true
tunnel.Local.Port = listener.Addr().(*net.TCPAddr).Port

// Ensure that MaxConnectionAttempts is at least 1. This check is done here
// since the library user can set the value at any point before Start() is called,
// and this check protects against the case where the programmer set MaxConnectionAttempts
// to 0 for some reason.
if tunnel.MaxConnectionAttempts <= 0 {
tunnel.MaxConnectionAttempts = 1
}

for {
if !tunnel.isOpen {
break
Expand Down Expand Up @@ -90,14 +99,29 @@ func (tunnel *SSHTunnel) Start() error {
}

func (tunnel *SSHTunnel) forward(localConn net.Conn) {
serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
if err != nil {
tunnel.logf("server dial error: %s", err)
return
var (
serverConn *ssh.Client
err error
attemptsLeft int = tunnel.MaxConnectionAttempts
)

for {
serverConn, err = ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
if err != nil {
attemptsLeft--

if attemptsLeft <= 0 {
tunnel.logf("server dial error: %v: exceeded %d attempts", err, tunnel.MaxConnectionAttempts)
return
}
} else {
break
}
}

tunnel.logf("connected to %s (1 of 2)\n", tunnel.Server.String())
tunnel.SvrConns = append(tunnel.SvrConns, serverConn)

remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String())
if err != nil {
tunnel.logf("remote dial error: %s", err)
Expand Down

0 comments on commit 27700fc

Please sign in to comment.