-
Notifications
You must be signed in to change notification settings - Fork 12
/
forward.go
114 lines (93 loc) · 2.49 KB
/
forward.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
package sshtun
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"

	"golang.org/x/sync/errgroup"
)
// TunneledConnState represents the state of the final connections made through the tunnel.
type TunneledConnState struct {
	// From is the address initiating the connection.
	From string
	// Info holds a message with info on the state of the connection (useful for debug purposes).
	Info string
	// Error holds an error on the connection or nil if the connection is successful.
	Error error
	// Ready indicates if the connection is established.
	Ready bool
}

// String renders the state as "[From] " followed by Info (when non-empty) and
// "Error: <err>" (when Error is non-nil). When both Info and Error are present
// they are separated by a single space so the two messages stay readable.
func (s *TunneledConnState) String() string {
	out := fmt.Sprintf("[%s] ", s.From)
	out += s.Info
	if s.Error != nil {
		if s.Info != "" {
			// Without a separator the output would read "...infoError: ...".
			out += " "
		}
		out += fmt.Sprintf("Error: %v", s.Error)
	}
	return out
}
// forward relays traffic between localConn, a connection accepted on the local
// side of the tunnel, and a new connection dialed through tun.sshClient to the
// remote endpoint. It blocks until either copy direction finishes or tun.ctx is
// canceled, then closes both connections. Progress and failures are reported
// through tun.tunneledState. localConn is closed on every return path.
func (tun *SSHTun) forward(localConn net.Conn) {
	from := localConn.RemoteAddr().String()
	tun.tunneledState(&TunneledConnState{
		From: from,
		Info: fmt.Sprintf("accepted %s connection", tun.local.Type()),
	})

	remoteConn, err := tun.sshClient.Dial(tun.remote.Type(), tun.remote.String())
	if err != nil {
		tun.tunneledState(&TunneledConnState{
			From:  from,
			Error: fmt.Errorf("remote dial %s to %s failed: %w", tun.remote.Type(), tun.remote.String(), err),
		})
		localConn.Close()
		return
	}

	connStr := fmt.Sprintf("%s -(%s)> %s -(ssh)> %s -(%s)> %s", from, tun.local.Type(), tun.local.String(),
		tun.server.String(), tun.remote.Type(), tun.remote.String())
	tun.tunneledState(&TunneledConnState{
		From:  from,
		Info:  fmt.Sprintf("connection established: %s", connStr),
		Ready: true,
	})

	connCtx, connCancel := context.WithCancel(tun.ctx)
	defer connCancel()

	// pipe builds an errgroup task that copies src into dst. Each task keeps its
	// own error variable, so the two directions never write shared state
	// concurrently. When either direction ends, it cancels connCtx, which
	// triggers the teardown below.
	pipe := func(dst io.Writer, src io.Reader, direction string) func() error {
		return func() error {
			defer connCancel()
			_, copyErr := io.Copy(dst, src)
			// net.ErrClosed is expected for the direction still running when
			// the other one finished, because teardown closes both connections.
			if copyErr != nil && !errors.Is(copyErr, net.ErrClosed) {
				return fmt.Errorf("failed copying bytes from %s: %w", direction, copyErr)
			}
			return nil
		}
	}

	errGroup := &errgroup.Group{}
	errGroup.Go(pipe(remoteConn, localConn, "local to remote"))
	errGroup.Go(pipe(localConn, remoteConn, "remote to local"))

	<-connCtx.Done()
	// Closing both ends unblocks whichever io.Copy is still running. Close
	// errors are ignored because both connections are discarded either way.
	localConn.Close()
	remoteConn.Close()

	// Copy errors are not reported when the whole tunnel is shutting down.
	if waitErr := errGroup.Wait(); waitErr != nil && tun.ctx.Err() == nil {
		tun.tunneledState(&TunneledConnState{
			From:  from,
			Error: waitErr,
		})
	}

	tun.tunneledState(&TunneledConnState{
		From: from,
		Info: fmt.Sprintf("connection closed: %s", connStr),
	})
}
// tunneledState hands state to the tunneledConnState callback configured on the
// tunnel, passing the tunnel itself as the first argument. It does nothing when
// no callback has been set.
func (tun *SSHTun) tunneledState(state *TunneledConnState) {
	callback := tun.tunneledConnState
	if callback == nil {
		return
	}
	callback(tun, state)
}