Skip to content

Commit 2b8115e

Browse files
committed
Copy UDP GSO support from tailscale
1 parent 06b4d4e commit 2b8115e

11 files changed

+1312
-819
lines changed

stack_system.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -244,7 +244,7 @@ func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) {
244244
}
245245
}
246246
if len(writeBuffers) > 0 {
247-
err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
247+
_, err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
248248
if err != nil {
249249
s.logger.Trace(E.Cause(err, "batch write packet"))
250250
}

tun.go

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,14 @@
11
package tun
22

33
import (
4-
"github.com/sagernet/sing/common/control"
54
"io"
65
"net"
76
"net/netip"
87
"runtime"
98
"strconv"
109
"strings"
1110

11+
"github.com/sagernet/sing/common/control"
1212
F "github.com/sagernet/sing/common/format"
1313
"github.com/sagernet/sing/common/logger"
1414
M "github.com/sagernet/sing/common/metadata"
@@ -39,7 +39,9 @@ type LinuxTUN interface {
3939
N.FrontHeadroom
4040
BatchSize() int
4141
BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error)
42-
BatchWrite(buffers [][]byte, offset int) error
42+
BatchWrite(buffers [][]byte, offset int) (n int, err error)
43+
DisableUDPGRO()
44+
DisableTCPGRO()
4345
TXChecksumOffload() bool
4446
}
4547

tun_darwin.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,13 +3,13 @@ package tun
33
import (
44
"errors"
55
"fmt"
6-
"github.com/sagernet/sing-tun/internal/gtcpip/header"
76
"net"
87
"net/netip"
98
"os"
109
"syscall"
1110
"unsafe"
1211

12+
"github.com/sagernet/sing-tun/internal/gtcpip/header"
1313
"github.com/sagernet/sing/common"
1414
"github.com/sagernet/sing/common/buf"
1515
"github.com/sagernet/sing/common/bufio"

tun_linux.go

Lines changed: 118 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@ package tun
22

33
import (
44
"errors"
5+
"fmt"
56
"math/rand"
67
"net"
78
"net/netip"
@@ -35,13 +36,15 @@ type NativeTun struct {
3536
interfaceCallback *list.Element[DefaultInterfaceUpdateCallback]
3637
options Options
3738
ruleIndex6 []int
38-
gsoEnabled bool
39-
gsoBuffer []byte
39+
readAccess sync.Mutex
40+
writeAccess sync.Mutex
41+
vnetHdr bool
42+
writeBuffer []byte
4043
gsoToWrite []int
41-
gsoReadAccess sync.Mutex
42-
tcpGROAccess sync.Mutex
43-
tcp4GROTable *tcpGROTable
44-
tcp6GROTable *tcpGROTable
44+
tcpGROTable *tcpGROTable
45+
udpGroAccess sync.Mutex
46+
udpGROTable *udpGROTable
47+
gro groDisablementFlags
4548
txChecksumOffload bool
4649
}
4750

@@ -81,20 +84,23 @@ func New(options Options) (Tun, error) {
8184
}
8285

8386
func (t *NativeTun) FrontHeadroom() int {
84-
if t.gsoEnabled {
87+
if t.vnetHdr {
8588
return virtioNetHdrLen
8689
}
8790
return 0
8891
}
8992

9093
func (t *NativeTun) Read(p []byte) (n int, err error) {
91-
if t.gsoEnabled {
92-
n, err = t.tunFile.Read(t.gsoBuffer)
94+
if t.vnetHdr {
95+
n, err = t.tunFile.Read(t.writeBuffer)
9396
if err != nil {
97+
if errors.Is(err, syscall.EBADFD) {
98+
err = os.ErrClosed
99+
}
94100
return
95101
}
96102
var sizes [1]int
97-
n, err = handleVirtioRead(t.gsoBuffer[:n], [][]byte{p}, sizes[:], 0)
103+
n, err = handleVirtioRead(t.writeBuffer[:n], [][]byte{p}, sizes[:], 0)
98104
if err != nil {
99105
return
100106
}
@@ -108,9 +114,50 @@ func (t *NativeTun) Read(p []byte) (n int, err error) {
108114
}
109115
}
110116

117+
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
118+
// each buffer. It mutates sizes to reflect the size of each element of bufs,
119+
// and returns the number of packets read.
120+
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
121+
var hdr virtioNetHdr
122+
err := hdr.decode(in)
123+
if err != nil {
124+
return 0, err
125+
}
126+
in = in[virtioNetHdrLen:]
127+
128+
options, err := hdr.toGSOOptions()
129+
if err != nil {
130+
return 0, err
131+
}
132+
133+
// Don't trust HdrLen from the kernel as it can be equal to the length
134+
// of the entire first packet when the kernel is handling it as part of a
135+
// FORWARD path. Instead, parse the transport header length and add it onto
136+
// CsumStart, which is synonymous for IP header length.
137+
if options.GSOType == GSOUDPL4 {
138+
options.HdrLen = options.CsumStart + 8
139+
} else if options.GSOType != GSONone {
140+
if len(in) <= int(options.CsumStart+12) {
141+
return 0, errors.New("packet is too short")
142+
}
143+
144+
tcpHLen := uint16(in[options.CsumStart+12] >> 4 * 4)
145+
if tcpHLen < 20 || tcpHLen > 60 {
146+
// A TCP header must be between 20 and 60 bytes in length.
147+
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
148+
}
149+
options.HdrLen = options.CsumStart + tcpHLen
150+
}
151+
152+
return GSOSplit(in, options, bufs, sizes, offset)
153+
}
154+
111155
func (t *NativeTun) Write(p []byte) (n int, err error) {
112-
if t.gsoEnabled {
113-
err = t.BatchWrite([][]byte{p}, virtioNetHdrLen)
156+
if t.vnetHdr {
157+
buffer := buf.Get(virtioNetHdrLen + len(p))
158+
copy(buffer[virtioNetHdrLen:], p)
159+
_, err = t.BatchWrite([][]byte{buffer}, virtioNetHdrLen)
160+
buf.Put(buffer)
114161
if err != nil {
115162
return
116163
}
@@ -121,7 +168,7 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
121168
}
122169

123170
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
124-
if t.gsoEnabled {
171+
if t.vnetHdr {
125172
n := buf.LenMulti(buffers)
126173
buffer := buf.NewSize(virtioNetHdrLen + n)
127174
buffer.Truncate(virtioNetHdrLen)
@@ -135,7 +182,7 @@ func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
135182
}
136183

137184
func (t *NativeTun) BatchSize() int {
138-
if !t.gsoEnabled {
185+
if !t.vnetHdr {
139186
return 1
140187
}
141188
/* // Not works on some devices: https://github.com/SagerNet/sing-box/issues/1605
@@ -147,36 +194,67 @@ func (t *NativeTun) BatchSize() int {
147194
return idealBatchSize
148195
}
149196

197+
// DisableUDPGRO disables UDP GRO if it is enabled. See the GRODevice interface
198+
// for cases where it should be called.
199+
func (t *NativeTun) DisableUDPGRO() {
200+
t.writeAccess.Lock()
201+
t.gro.disableUDPGRO()
202+
t.writeAccess.Unlock()
203+
}
204+
205+
// DisableTCPGRO disables TCP GRO if it is enabled. See the GRODevice interface
206+
// for cases where it should be called.
207+
func (t *NativeTun) DisableTCPGRO() {
208+
t.writeAccess.Lock()
209+
t.gro.disableTCPGRO()
210+
t.writeAccess.Unlock()
211+
}
212+
150213
func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) {
151-
t.gsoReadAccess.Lock()
152-
defer t.gsoReadAccess.Unlock()
153-
n, err = t.tunFile.Read(t.gsoBuffer)
214+
t.readAccess.Lock()
215+
defer t.readAccess.Unlock()
216+
n, err = t.tunFile.Read(t.writeBuffer)
154217
if err != nil {
155218
return
156219
}
157-
return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset)
220+
return handleVirtioRead(t.writeBuffer[:n], buffers, readN, offset)
158221
}
159222

160-
func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error {
161-
t.tcpGROAccess.Lock()
223+
func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) (int, error) {
224+
t.writeAccess.Lock()
162225
defer func() {
163-
t.tcp4GROTable.reset()
164-
t.tcp6GROTable.reset()
165-
t.tcpGROAccess.Unlock()
226+
t.tcpGROTable.reset()
227+
t.udpGROTable.reset()
228+
t.writeAccess.Unlock()
166229
}()
230+
var (
231+
errs error
232+
total int
233+
)
167234
t.gsoToWrite = t.gsoToWrite[:0]
168-
err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite)
169-
if err != nil {
170-
return err
235+
if t.vnetHdr {
236+
err := handleGRO(buffers, offset, t.tcpGROTable, t.udpGROTable, t.gro, &t.gsoToWrite)
237+
if err != nil {
238+
return 0, err
239+
}
240+
offset -= virtioNetHdrLen
241+
} else {
242+
for i := range buffers {
243+
t.gsoToWrite = append(t.gsoToWrite, i)
244+
}
171245
}
172-
offset -= virtioNetHdrLen
173-
for _, bufferIndex := range t.gsoToWrite {
174-
_, err = t.tunFile.Write(buffers[bufferIndex][offset:])
246+
for _, toWrite := range t.gsoToWrite {
247+
n, err := t.tunFile.Write(buffers[toWrite][offset:])
248+
if errors.Is(err, syscall.EBADFD) {
249+
return total, os.ErrClosed
250+
}
175251
if err != nil {
176-
return err
252+
errs = errors.Join(errs, err)
253+
} else {
254+
total += n
177255
}
178256
}
179-
return nil
257+
return total, errs
180258
}
181259

182260
var controlPath string
@@ -262,10 +340,14 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
262340
if err != nil {
263341
return err
264342
}
265-
t.gsoEnabled = true
266-
t.gsoBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize))
267-
t.tcp4GROTable = newTCPGROTable()
268-
t.tcp6GROTable = newTCPGROTable()
343+
t.vnetHdr = true
344+
t.writeBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize))
345+
t.tcpGROTable = newTCPGROTable()
346+
t.udpGROTable = newUDPGROTable()
347+
err = setUDPOffload(t.tunFd)
348+
if err != nil {
349+
t.gro.disableUDPGRO()
350+
}
269351
}
270352

271353
var rxChecksumOffload bool
@@ -280,7 +362,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
280362
if err != nil {
281363
return err
282364
}
283-
if err == nil && !txChecksumOffload {
365+
if !txChecksumOffload {
284366
err = setChecksumOffload(t.options.Name, unix.ETHTOOL_STXCSUM)
285367
if err != nil {
286368
return err

tun_linux_flags.go

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,12 @@ import (
1212
"golang.org/x/sys/unix"
1313
)
1414

15+
const (
16+
// TODO: support TSO with ECN bits
17+
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
18+
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
19+
)
20+
1521
func checkVNETHDREnabled(fd int, name string) (bool, error) {
1622
ifr, err := unix.NewIfreq(name)
1723
if err != nil {
@@ -25,17 +31,17 @@ func checkVNETHDREnabled(fd int, name string) (bool, error) {
2531
}
2632

2733
func setTCPOffload(fd int) error {
28-
const (
29-
// TODO: support TSO with ECN bits
30-
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
31-
)
32-
err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunOffloads)
34+
err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads)
3335
if err != nil {
3436
return E.Cause(os.NewSyscallError("TUNSETOFFLOAD", err), "enable offload")
3537
}
3638
return nil
3739
}
3840

41+
func setUDPOffload(fd int) error {
42+
return unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads)
43+
}
44+
3945
type ifreqData struct {
4046
ifrName [unix.IFNAMSIZ]byte
4147
ifrData uintptr

tun_linux_gvisor.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ import (
1010
var _ GVisorTun = (*NativeTun)(nil)
1111

1212
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
13-
if t.gsoEnabled {
13+
if t.vnetHdr {
1414
return fdbased.New(&fdbased.Options{
1515
FDs: []int{t.tunFd},
1616
MTU: t.options.MTU,

0 commit comments

Comments
 (0)