Skip to content

Commit

Permalink
Let the listener open more than 1 connection (#6)
Browse files Browse the repository at this point in the history
Fix the issue that was introduced in 1.1.0 where a listener would only accept a single connection. Instead track all the open connections in SSHTunnel.Conns (and SSHTunnel.SvrConns) and close all connections when SSHTunnel.Close() is called.

I think this addresses the issue #5
  • Loading branch information
isaaguilar committed Jul 25, 2020
1 parent 52c7403 commit 8447f1e
Showing 1 changed file with 59 additions and 24 deletions.
83 changes: 59 additions & 24 deletions ssh_tunnel.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
package sshtunnel

import (
"golang.org/x/crypto/ssh"
"io"
"log"
"net"
"sync"

"golang.org/x/crypto/ssh"
)

type SSHTunnel struct {
Local *Endpoint
Server *Endpoint
Remote *Endpoint
Config *ssh.ClientConfig
Log *log.Logger
close chan interface{}
Local *Endpoint
Server *Endpoint
Remote *Endpoint
Config *ssh.ClientConfig
Log *log.Logger
Conns []net.Conn
SvrConns []*ssh.Client
isOpen bool
close chan interface{}
}

func (tunnel *SSHTunnel) logf(fmt string, args ...interface{}) {
Expand All @@ -23,32 +26,67 @@ func (tunnel *SSHTunnel) logf(fmt string, args ...interface{}) {
}
}

func newConnectionWaiter(listener net.Listener, c chan net.Conn) {
conn, err := listener.Accept()
if err != nil {
return
}
c <- conn
}

func (tunnel *SSHTunnel) Start() error {
listener, err := net.Listen("tcp", tunnel.Local.String())
if err != nil {
return err
}
tunnel.isOpen = true
tunnel.Local.Port = listener.Addr().(*net.TCPAddr).Port

for {
conn, err := listener.Accept()
if !tunnel.isOpen {
break
}

c := make(chan net.Conn)
go newConnectionWaiter(listener, c)
tunnel.logf("listening for new connections...")

select {
case <-tunnel.close:
tunnel.logf("close signal received, closing...")
tunnel.isOpen = false
case conn := <-c:
tunnel.Conns = append(tunnel.Conns, conn)
tunnel.logf("accepted connection")
go tunnel.forward(conn)
}
}
var total int
total = len(tunnel.Conns)
for i, conn := range tunnel.Conns {
tunnel.logf("closing the netConn (%d of %d)", i+1, total)
err := conn.Close()
if err != nil {
return err
tunnel.logf(err.Error())
}
}
total = len(tunnel.SvrConns)
for i, conn := range tunnel.SvrConns {
tunnel.logf("closing the serverConn (%d of %d)", i+1, total)
err := conn.Close()
if err != nil {
tunnel.logf(err.Error())
}
tunnel.logf("accepted connection")
var wg sync.WaitGroup
go tunnel.forward(conn, &wg)
wg.Wait()
tunnel.logf("tunnel closed")
break
}
err = listener.Close()
if err != nil {
return err
}
tunnel.logf("tunnel closed")
return nil
}

func (tunnel *SSHTunnel) forward(localConn net.Conn, wg *sync.WaitGroup) {
func (tunnel *SSHTunnel) forward(localConn net.Conn) {
serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
if err != nil {
tunnel.logf("server dial error: %s", err)
Expand All @@ -60,6 +98,8 @@ func (tunnel *SSHTunnel) forward(localConn net.Conn, wg *sync.WaitGroup) {
tunnel.logf("remote dial error: %s", err)
return
}
tunnel.Conns = append(tunnel.Conns, remoteConn)
tunnel.SvrConns = append(tunnel.SvrConns, serverConn)
tunnel.logf("connected to %s (2 of 2)\n", tunnel.Remote.String())
copyConn := func(writer, reader net.Conn) {
_, err := io.Copy(writer, reader)
Expand All @@ -69,12 +109,7 @@ func (tunnel *SSHTunnel) forward(localConn net.Conn, wg *sync.WaitGroup) {
}
go copyConn(localConn, remoteConn)
go copyConn(remoteConn, localConn)
<-tunnel.close
tunnel.logf("close signal received, closing...")
_ = localConn.Close()
_ = serverConn.Close()
_ = remoteConn.Close()
wg.Done()

return
}

Expand All @@ -86,7 +121,7 @@ func (tunnel *SSHTunnel) Close() {
// NewSSHTunnel creates a new single-use tunnel. Supplying "0" for localport will use a random port.
func NewSSHTunnel(tunnel string, auth ssh.AuthMethod, destination string, localport string) *SSHTunnel {

localEndpoint := NewEndpoint("localhost:"+localport)
localEndpoint := NewEndpoint("localhost:" + localport)

server := NewEndpoint(tunnel)
if server.Port == 0 {
Expand Down

0 comments on commit 8447f1e

Please sign in to comment.