Skip to content

Commit

Permalink
Merge pull request #5 from Recidiviz/colincadams/r/string-func-null
Browse files Browse the repository at this point in the history
Harden string functions when NULL is passed
  • Loading branch information
colincadams authored and ohaibbq committed Jan 16, 2024
1 parent bb77650 commit a7853c3
Show file tree
Hide file tree
Showing 3 changed files with 274 additions and 88 deletions.
100 changes: 86 additions & 14 deletions internal/function_bind.go
Expand Up @@ -798,16 +798,22 @@ func bindCollate(args ...Value) (Value, error) {
}

func bindConcat(args ...Value) (Value, error) {
if len(args) < 2 {
if len(args) < 1 {
return nil, fmt.Errorf("CONCAT: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return CONCAT(args...)
}

func bindContainsSubstr(args ...Value) (Value, error) {
if args[1] == nil {
return nil, fmt.Errorf("CONTAINS_SUBSTR: search literal must be not null")
}
if existsNull(args) {
return nil, nil
}
search, err := args[1].ToString()
if err != nil {
return nil, err
Expand All @@ -819,21 +825,24 @@ func bindEndsWith(args ...Value) (Value, error) {
if len(args) != 2 {
return nil, fmt.Errorf("ENDS_WITH: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return ENDS_WITH(args[0], args[1])
}

func bindFormat(args ...Value) (Value, error) {
if len(args) == 0 {
return nil, fmt.Errorf("FORMAT: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
format, err := args[0].ToString()
if err != nil {
return nil, err
}
if len(args) > 1 {
if args[1] == nil {
return nil, nil
}
return FORMAT(format, args[1:]...)
}
return FORMAT(format)
Expand All @@ -843,6 +852,9 @@ func bindFromBase32(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("FROM_BASE32: invalid argument num %d", len(args))
}
if args[0] == nil {
return nil, nil
}
v, err := args[0].ToString()
if err != nil {
return nil, err
Expand All @@ -854,6 +866,9 @@ func bindFromBase64(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("FROM_BASE64: invalid argument num %d", len(args))
}
if args[0] == nil {
return nil, nil
}
v, err := args[0].ToString()
if err != nil {
return nil, err
Expand All @@ -865,6 +880,9 @@ func bindFromHex(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("FROM_HEX: invalid argument num %d", len(args))
}
if args[0] == nil {
return nil, nil
}
v, err := args[0].ToString()
if err != nil {
return nil, err
Expand All @@ -876,14 +894,11 @@ func bindInitcap(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("INITCAP: invalid argument num %d", len(args))
}
if args[0] == nil {
if existsNull(args) {
return nil, nil
}
var delimiters []rune
if len(args) == 2 {
if args[1] == nil {
return nil, nil
}
v, err := args[1].ToString()
if err != nil {
return nil, err
Expand All @@ -904,10 +919,7 @@ func bindInstr(args ...Value) (Value, error) {
if len(args) != 2 && len(args) != 3 && len(args) != 4 {
return nil, fmt.Errorf("INSTR: invalid argument num %d", len(args))
}
if args[0] == nil {
return nil, nil
}
if args[1] == nil {
if existsNull(args) {
return nil, nil
}
var (
Expand Down Expand Up @@ -935,6 +947,9 @@ func bindLeft(args ...Value) (Value, error) {
if len(args) != 2 {
return nil, fmt.Errorf("LEFT: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
length, err := args[1].ToInt64()
if err != nil {
return nil, err
Expand All @@ -947,7 +962,7 @@ func bindLength(args ...Value) (Value, error) {
return nil, fmt.Errorf("LENGTH: invalid argument num %d", len(args))
}
if args[0] == nil {
return IntValue(0), nil
return nil, nil
}
return LENGTH(args[0])
}
Expand All @@ -956,6 +971,9 @@ func bindLpad(args ...Value) (Value, error) {
if len(args) != 2 && len(args) != 3 {
return nil, fmt.Errorf("LPAD: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
var pattern Value
if len(args) == 3 {
pattern = args[2]
Expand All @@ -971,13 +989,19 @@ func bindLower(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("LOWER: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return LOWER(args[0])
}

func bindLtrim(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("LTRIM: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
cutset := " "
if len(args) == 2 {
v, err := args[1].ToString()
Expand All @@ -993,6 +1017,9 @@ func bindNormalize(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("NORMALIZE: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
mode := "NFC"
if len(args) == 2 {
v, err := args[1].ToString()
Expand All @@ -1012,6 +1039,9 @@ func bindNormalizeAndCasefold(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("NORMALIZE_AND_CASEFOLD: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
mode := "NFC"
if len(args) == 2 {
v, err := args[1].ToString()
Expand Down Expand Up @@ -1043,6 +1073,9 @@ func bindRegexpContains(args ...Value) (Value, error) {
}

func bindRegexpExtract(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
}
regexp, err := args[1].ToString()
if err != nil {
return nil, err
Expand All @@ -1067,6 +1100,9 @@ func bindRegexpExtract(args ...Value) (Value, error) {
}

func bindRegexpExtractAll(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
}
regexp, err := args[1].ToString()
if err != nil {
return nil, err
Expand All @@ -1075,6 +1111,9 @@ func bindRegexpExtractAll(args ...Value) (Value, error) {
}

func bindRegexpInstr(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
}
var (
pos int64 = 1
occurrence int64 = 1
Expand Down Expand Up @@ -1105,6 +1144,9 @@ func bindRegexpInstr(args ...Value) (Value, error) {
}

func bindRegexpReplace(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
}
return REGEXP_REPLACE(args[0], args[1], args[2])
}

Expand Down Expand Up @@ -1198,7 +1240,7 @@ func bindSoundex(args ...Value) (Value, error) {

func bindSplit(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
return &ArrayValue{}, nil
}
var delim Value
if len(args) > 1 {
Expand All @@ -1211,20 +1253,29 @@ func bindStartsWith(args ...Value) (Value, error) {
if len(args) != 2 {
return nil, fmt.Errorf("STARTS_WITH: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return STARTS_WITH(args[0], args[1])
}

func bindStrpos(args ...Value) (Value, error) {
if len(args) != 2 {
return nil, fmt.Errorf("STRPOS: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return STRPOS(args[0], args[1])
}

func bindSubstr(args ...Value) (Value, error) {
if len(args) != 2 && len(args) != 3 {
return nil, fmt.Errorf("SUBSTR: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
pos, err := args[1].ToInt64()
if err != nil {
return nil, err
Expand All @@ -1244,6 +1295,9 @@ func bindToBase32(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("TO_BASE32: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
b, err := args[0].ToBytes()
if err != nil {
return nil, err
Expand All @@ -1255,6 +1309,9 @@ func bindToBase64(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("TO_BASE64: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
b, err := args[0].ToBytes()
if err != nil {
return nil, err
Expand All @@ -1266,13 +1323,19 @@ func bindToCodePoints(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("TO_CODE_POINTS: invalid argument num %d", len(args))
}
if args[0] == nil {
return &ArrayValue{}, nil
}
return TO_CODE_POINTS(args[0])
}

func bindToHex(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("TO_HEX: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
b, err := args[0].ToBytes()
if err != nil {
return nil, err
Expand All @@ -1284,13 +1347,19 @@ func bindTranslate(args ...Value) (Value, error) {
if len(args) != 3 {
return nil, fmt.Errorf("TRANSLATE: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return TRANSLATE(args[0], args[1], args[2])
}

func bindTrim(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("TRIM: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
if len(args) == 2 {
return TRIM(args[0], args[1])
}
Expand Down Expand Up @@ -1460,6 +1529,9 @@ func bindToJson(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("TO_JSON: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
var stringifyWideNumbers bool
if len(args) == 2 {
b, err := args[1].ToBool()
Expand Down
5 changes: 4 additions & 1 deletion internal/function_string.go
Expand Up @@ -415,6 +415,9 @@ func LPAD(originalValue Value, returnLength int64, pattern Value) (Value, error)
}

func LOWER(v Value) (Value, error) {
if v == nil {
return nil, nil
}
switch v.(type) {
case StringValue:
s, err := v.ToString()
Expand Down Expand Up @@ -697,7 +700,7 @@ func REGEXP_REPLACE(value, exprValue, replacementValue Value) (Value, error) {
}
return BytesValue(re.ReplaceAll(v, []byte(normalizeReplacement(string(replacement))))), nil
}
return nil, fmt.Errorf("REGEXP_REPLACE: value must be STRING or BYTES")
return nil, fmt.Errorf("REGEXP_REPLACE: value must be STRING or BYTES, %s", value)
}

func REPLACE(originalValue, fromValue, toValue Value) (Value, error) {
Expand Down

0 comments on commit a7853c3

Please sign in to comment.