Skip to content

Commit

Permalink
THRIFT-5369: Use MaxMessageSize to check container sizes
Browse files Browse the repository at this point in the history
Client: go
  • Loading branch information
fishy authored and Jens-G committed Jun 12, 2021
1 parent 63e86ce commit 57e24ca
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 32 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Expand Up @@ -6,6 +6,10 @@

- [THRIFT-5383](https://issues.apache.org/jira/browse/THRIFT-5383) - THRIFT-5383 TJSONProtocol Java readString throws on bounds check

### Go

- [THRIFT-5369](https://issues.apache.org/jira/browse/THRIFT-5369) - No longer pre-allocating the whole container (map/set/list) in compiled go code to avoid huge allocations on malformed messages


## 0.14.1

Expand Down
19 changes: 6 additions & 13 deletions lib/go/thrift/binary_protocol.go
Expand Up @@ -23,7 +23,6 @@ import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
Expand Down Expand Up @@ -334,8 +333,6 @@ func (p *TBinaryProtocol) ReadFieldEnd(ctx context.Context) error {
return nil
}

var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length"))

func (p *TBinaryProtocol) ReadMapBegin(ctx context.Context) (kType, vType TType, size int, err error) {
k, e := p.ReadByte(ctx)
if e != nil {
Expand All @@ -354,8 +351,8 @@ func (p *TBinaryProtocol) ReadMapBegin(ctx context.Context) (kType, vType TType,
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
err = checkSizeForProtocol(size32, p.cfg)
if err != nil {
return
}
size = int(size32)
Expand All @@ -378,8 +375,8 @@ func (p *TBinaryProtocol) ReadListBegin(ctx context.Context) (elemType TType, si
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
err = checkSizeForProtocol(size32, p.cfg)
if err != nil {
return
}
size = int(size32)
Expand All @@ -403,8 +400,8 @@ func (p *TBinaryProtocol) ReadSetBegin(ctx context.Context) (elemType TType, siz
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
err = checkSizeForProtocol(size32, p.cfg)
if err != nil {
return
}
size = int(size32)
Expand Down Expand Up @@ -466,10 +463,6 @@ func (p *TBinaryProtocol) ReadString(ctx context.Context) (value string, err err
if err != nil {
return
}
if size < 0 {
err = invalidDataLength
return
}
if size == 0 {
return "", nil
}
Expand Down
12 changes: 6 additions & 6 deletions lib/go/thrift/compact_protocol.go
Expand Up @@ -477,8 +477,8 @@ func (p *TCompactProtocol) ReadMapBegin(ctx context.Context) (keyType TType, val
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
err = checkSizeForProtocol(size32, p.cfg)
if err != nil {
return
}
size = int(size32)
Expand Down Expand Up @@ -513,12 +513,12 @@ func (p *TCompactProtocol) ReadListBegin(ctx context.Context) (elemType TType, s
err = NewTProtocolException(e)
return
}
if size2 < 0 {
err = invalidDataLength
return
}
size = int(size2)
}
err = checkSizeForProtocol(size32, p.cfg)
if err != nil {
return
}
elemType, e := p.getTType(tCompactType(size_and_type))
if e != nil {
err = NewTProtocolException(e)
Expand Down
21 changes: 16 additions & 5 deletions lib/go/thrift/json_protocol.go
Expand Up @@ -311,9 +311,13 @@ func (p *TJSONProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueT
}

// read size
iSize, e := p.ReadI64(ctx)
if e != nil {
return keyType, valueType, size, e
iSize, err := p.ReadI64(ctx)
if err != nil {
return keyType, valueType, size, err
}
err = checkSizeForProtocol(int32(iSize), p.cfg)
if err != nil {
return keyType, valueType, 0, err
}
size = int(iSize)

Expand Down Expand Up @@ -485,9 +489,16 @@ func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error)
if err != nil {
return elemType, size, err
}
nSize, _, err2 := p.ParseI64()
nSize, _, err := p.ParseI64()
if err != nil {
return elemType, 0, err
}
err = checkSizeForProtocol(int32(nSize), p.cfg)
if err != nil {
return elemType, 0, err
}
size = int(nSize)
return elemType, size, err2
return elemType, size, nil
}

func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error) {
Expand Down
63 changes: 55 additions & 8 deletions lib/go/thrift/simple_json_protocol.go
Expand Up @@ -97,16 +97,27 @@ var errEmptyJSONContextStack = NewTProtocolExceptionWithType(INVALID_DATA, error
type TSimpleJSONProtocol struct {
trans TTransport

cfg *TConfiguration

parseContextStack jsonContextStack
dumpContext jsonContextStack

writer *bufio.Writer
reader *bufio.Reader
}

// Constructor
// Deprecated: Use NewTSimpleJSONProtocolConf instead.:
func NewTSimpleJSONProtocol(t TTransport) *TSimpleJSONProtocol {
v := &TSimpleJSONProtocol{trans: t,
return NewTSimpleJSONProtocolConf(t, &TConfiguration{
noPropagation: true,
})
}

func NewTSimpleJSONProtocolConf(t TTransport, conf *TConfiguration) *TSimpleJSONProtocol {
PropagateTConfiguration(t, conf)
v := &TSimpleJSONProtocol{
trans: t,
cfg: conf,
writer: bufio.NewWriter(t),
reader: bufio.NewReader(t),
}
Expand All @@ -116,14 +127,32 @@ func NewTSimpleJSONProtocol(t TTransport) *TSimpleJSONProtocol {
}

// Factory
type TSimpleJSONProtocolFactory struct{}
type TSimpleJSONProtocolFactory struct {
cfg *TConfiguration
}

func (p *TSimpleJSONProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return NewTSimpleJSONProtocol(trans)
return NewTSimpleJSONProtocolConf(trans, p.cfg)
}

// SetTConfiguration implements TConfigurationSetter for propagation.
func (p *TSimpleJSONProtocolFactory) SetTConfiguration(conf *TConfiguration) {
p.cfg = conf
}

// Deprecated: Use NewTSimpleJSONProtocolFactoryConf instead.
func NewTSimpleJSONProtocolFactory() *TSimpleJSONProtocolFactory {
return &TSimpleJSONProtocolFactory{}
return &TSimpleJSONProtocolFactory{
cfg: &TConfiguration{
noPropagation: true,
},
}
}

func NewTSimpleJSONProtocolFactoryConf(conf *TConfiguration) *TSimpleJSONProtocolFactory {
return &TSimpleJSONProtocolFactory{
cfg: conf,
}
}

var (
Expand Down Expand Up @@ -399,6 +428,13 @@ func (p *TSimpleJSONProtocol) ReadMapBegin(ctx context.Context) (keyType TType,

// read size
iSize, err := p.ReadI64(ctx)
if err != nil {
return keyType, valueType, 0, err
}
err = checkSizeForProtocol(int32(size), p.cfg)
if err != nil {
return keyType, valueType, 0, err
}
size = int(iSize)
return keyType, valueType, size, err
}
Expand Down Expand Up @@ -1070,9 +1106,16 @@ func (p *TSimpleJSONProtocol) ParseElemListBegin() (elemType TType, size int, e
if err != nil {
return elemType, size, err
}
nSize, _, err2 := p.ParseI64()
nSize, _, err := p.ParseI64()
if err != nil {
return elemType, 0, err
}
err = checkSizeForProtocol(int32(nSize), p.cfg)
if err != nil {
return elemType, 0, err
}
size = int(nSize)
return elemType, size, err2
return elemType, size, nil
}

func (p *TSimpleJSONProtocol) ParseListEnd() error {
Expand Down Expand Up @@ -1368,6 +1411,10 @@ func (p *TSimpleJSONProtocol) write(b []byte) (int, error) {
// SetTConfiguration implements TConfigurationSetter for propagation.
func (p *TSimpleJSONProtocol) SetTConfiguration(conf *TConfiguration) {
PropagateTConfiguration(p.trans, conf)
p.cfg = conf
}

var _ TConfigurationSetter = (*TSimpleJSONProtocol)(nil)
var (
_ TConfigurationSetter = (*TSimpleJSONProtocol)(nil)
_ TConfigurationSetter = (*TSimpleJSONProtocolFactory)(nil)
)

0 comments on commit 57e24ca

Please sign in to comment.