Skip to content

Commit

Permalink
Merge pull request #5 from Recidiviz/colincadams/r/string-func-null
Browse files Browse the repository at this point in the history
Harden string functions when NULL is passed
  • Loading branch information
colincadams committed Jun 7, 2023
2 parents a9a764f + c8bdbe7 commit b2934d9
Show file tree
Hide file tree
Showing 3 changed files with 274 additions and 88 deletions.
100 changes: 86 additions & 14 deletions internal/function_bind.go
Expand Up @@ -775,16 +775,22 @@ func bindCollate(args ...Value) (Value, error) {
}

func bindConcat(args ...Value) (Value, error) {
if len(args) < 2 {
if len(args) < 1 {
return nil, fmt.Errorf("CONCAT: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return CONCAT(args...)
}

func bindContainsSubstr(args ...Value) (Value, error) {
if args[1] == nil {
return nil, fmt.Errorf("CONTAINS_SUBSTR: search literal must be not null")
}
if existsNull(args) {
return nil, nil
}
search, err := args[1].ToString()
if err != nil {
return nil, err
Expand All @@ -796,21 +802,24 @@ func bindEndsWith(args ...Value) (Value, error) {
if len(args) != 2 {
return nil, fmt.Errorf("ENDS_WITH: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return ENDS_WITH(args[0], args[1])
}

func bindFormat(args ...Value) (Value, error) {
if len(args) == 0 {
return nil, fmt.Errorf("FORMAT: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
format, err := args[0].ToString()
if err != nil {
return nil, err
}
if len(args) > 1 {
if args[1] == nil {
return nil, nil
}
return FORMAT(format, args[1:]...)
}
return FORMAT(format)
Expand All @@ -820,6 +829,9 @@ func bindFromBase32(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("FROM_BASE32: invalid argument num %d", len(args))
}
if args[0] == nil {
return nil, nil
}
v, err := args[0].ToString()
if err != nil {
return nil, err
Expand All @@ -831,6 +843,9 @@ func bindFromBase64(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("FROM_BASE64: invalid argument num %d", len(args))
}
if args[0] == nil {
return nil, nil
}
v, err := args[0].ToString()
if err != nil {
return nil, err
Expand All @@ -842,6 +857,9 @@ func bindFromHex(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("FROM_HEX: invalid argument num %d", len(args))
}
if args[0] == nil {
return nil, nil
}
v, err := args[0].ToString()
if err != nil {
return nil, err
Expand All @@ -853,14 +871,11 @@ func bindInitcap(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("INITCAP: invalid argument num %d", len(args))
}
if args[0] == nil {
if existsNull(args) {
return nil, nil
}
var delimiters []rune
if len(args) == 2 {
if args[1] == nil {
return nil, nil
}
v, err := args[1].ToString()
if err != nil {
return nil, err
Expand All @@ -881,10 +896,7 @@ func bindInstr(args ...Value) (Value, error) {
if len(args) != 2 && len(args) != 3 && len(args) != 4 {
return nil, fmt.Errorf("INSTR: invalid argument num %d", len(args))
}
if args[0] == nil {
return nil, nil
}
if args[1] == nil {
if existsNull(args) {
return nil, nil
}
var (
Expand Down Expand Up @@ -912,6 +924,9 @@ func bindLeft(args ...Value) (Value, error) {
if len(args) != 2 {
return nil, fmt.Errorf("LEFT: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
length, err := args[1].ToInt64()
if err != nil {
return nil, err
Expand All @@ -924,7 +939,7 @@ func bindLength(args ...Value) (Value, error) {
return nil, fmt.Errorf("LENGTH: invalid argument num %d", len(args))
}
if args[0] == nil {
return IntValue(0), nil
return nil, nil
}
return LENGTH(args[0])
}
Expand All @@ -933,6 +948,9 @@ func bindLpad(args ...Value) (Value, error) {
if len(args) != 2 && len(args) != 3 {
return nil, fmt.Errorf("LPAD: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
var pattern Value
if len(args) == 3 {
pattern = args[2]
Expand All @@ -948,13 +966,19 @@ func bindLower(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("LOWER: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return LOWER(args[0])
}

func bindLtrim(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("LTRIM: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
cutset := " "
if len(args) == 2 {
v, err := args[1].ToString()
Expand All @@ -970,6 +994,9 @@ func bindNormalize(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("NORMALIZE: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
mode := "NFC"
if len(args) == 2 {
v, err := args[1].ToString()
Expand All @@ -989,6 +1016,9 @@ func bindNormalizeAndCasefold(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("NORMALIZE_AND_CASEFOLD: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
mode := "NFC"
if len(args) == 2 {
v, err := args[1].ToString()
Expand Down Expand Up @@ -1020,6 +1050,9 @@ func bindRegexpContains(args ...Value) (Value, error) {
}

func bindRegexpExtract(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
}
regexp, err := args[1].ToString()
if err != nil {
return nil, err
Expand All @@ -1044,6 +1077,9 @@ func bindRegexpExtract(args ...Value) (Value, error) {
}

func bindRegexpExtractAll(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
}
regexp, err := args[1].ToString()
if err != nil {
return nil, err
Expand All @@ -1052,6 +1088,9 @@ func bindRegexpExtractAll(args ...Value) (Value, error) {
}

func bindRegexpInstr(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
}
var (
pos int64 = 1
occurrence int64 = 1
Expand Down Expand Up @@ -1082,6 +1121,9 @@ func bindRegexpInstr(args ...Value) (Value, error) {
}

// bindRegexpReplace validates the arguments of a REGEXP_REPLACE call and
// forwards them to the REGEXP_REPLACE implementation.
//
// Exactly three arguments are required; they are passed to REGEXP_REPLACE in
// the order received. Any other argument count yields an error instead of
// panicking on an out-of-range index. When existsNull reports true for the
// arguments, the result is (nil, nil) and REGEXP_REPLACE is not invoked.
func bindRegexpReplace(args ...Value) (Value, error) {
	// Guard the fixed-position indexing below, matching the arity checks
	// performed by the other bind* helpers in this file.
	if len(args) != 3 {
		return nil, fmt.Errorf("REGEXP_REPLACE: invalid argument num %d", len(args))
	}
	if existsNull(args) {
		return nil, nil
	}
	return REGEXP_REPLACE(args[0], args[1], args[2])
}

Expand Down Expand Up @@ -1175,7 +1217,7 @@ func bindSoundex(args ...Value) (Value, error) {

func bindSplit(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
return &ArrayValue{}, nil
}
var delim Value
if len(args) > 1 {
Expand All @@ -1188,20 +1230,29 @@ func bindStartsWith(args ...Value) (Value, error) {
if len(args) != 2 {
return nil, fmt.Errorf("STARTS_WITH: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return STARTS_WITH(args[0], args[1])
}

// bindStrpos checks the arguments of a STRPOS call and delegates to STRPOS.
//
// It requires exactly two arguments and returns an error for any other count.
// When existsNull reports true for the arguments, it returns (nil, nil)
// without calling STRPOS; otherwise both arguments are handed to STRPOS in
// the order they were received.
func bindStrpos(args ...Value) (Value, error) {
	if argc := len(args); argc != 2 {
		return nil, fmt.Errorf("STRPOS: invalid argument num %d", argc)
	}

	if existsNull(args) {
		return nil, nil
	}

	first, second := args[0], args[1]
	return STRPOS(first, second)
}

func bindSubstr(args ...Value) (Value, error) {
if len(args) != 2 && len(args) != 3 {
return nil, fmt.Errorf("SUBSTR: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
pos, err := args[1].ToInt64()
if err != nil {
return nil, err
Expand All @@ -1221,6 +1272,9 @@ func bindToBase32(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("TO_BASE32: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
b, err := args[0].ToBytes()
if err != nil {
return nil, err
Expand All @@ -1232,6 +1286,9 @@ func bindToBase64(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("TO_BASE64: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
b, err := args[0].ToBytes()
if err != nil {
return nil, err
Expand All @@ -1243,13 +1300,19 @@ func bindToCodePoints(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("TO_CODE_POINTS: invalid argument num %d", len(args))
}
if args[0] == nil {
return &ArrayValue{}, nil
}
return TO_CODE_POINTS(args[0])
}

func bindToHex(args ...Value) (Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("TO_HEX: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
b, err := args[0].ToBytes()
if err != nil {
return nil, err
Expand All @@ -1261,13 +1324,19 @@ func bindTranslate(args ...Value) (Value, error) {
if len(args) != 3 {
return nil, fmt.Errorf("TRANSLATE: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
return TRANSLATE(args[0], args[1], args[2])
}

func bindTrim(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("TRIM: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
if len(args) == 2 {
return TRIM(args[0], args[1])
}
Expand Down Expand Up @@ -1437,6 +1506,9 @@ func bindToJson(args ...Value) (Value, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("TO_JSON: invalid argument num %d", len(args))
}
if existsNull(args) {
return nil, nil
}
var stringifyWideNumbers bool
if len(args) == 2 {
b, err := args[1].ToBool()
Expand Down
5 changes: 4 additions & 1 deletion internal/function_string.go
Expand Up @@ -415,6 +415,9 @@ func LPAD(originalValue Value, returnLength int64, pattern Value) (Value, error)
}

func LOWER(v Value) (Value, error) {
if v == nil {
return nil, nil
}
switch v.(type) {
case StringValue:
s, err := v.ToString()
Expand Down Expand Up @@ -697,7 +700,7 @@ func REGEXP_REPLACE(value, exprValue, replacementValue Value) (Value, error) {
}
return BytesValue(re.ReplaceAll(v, []byte(normalizeReplacement(string(replacement))))), nil
}
return nil, fmt.Errorf("REGEXP_REPLACE: value must be STRING or BYTES")
return nil, fmt.Errorf("REGEXP_REPLACE: value must be STRING or BYTES, %s", value)
}

func REPLACE(originalValue, fromValue, toValue Value) (Value, error) {
Expand Down

0 comments on commit b2934d9

Please sign in to comment.