
fix(presto): fix parsing and generating hash functions presto/trino #3459

Merged: 1 commit merged into tobymao:main on May 11, 2024


viplazylmht (Contributor)

Fixes #3458

Affected functions:

  • MD5_DIGEST
  • MD5
  • SHA
  • SHA2
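The affected functions differ mainly in the shape of the value they return. Python's hashlib is used below only as an analogy, not as sqlglot code. A binary MD5 digest (the MD5_DIGEST shape) is 16 raw bytes, which is what Presto's MD5 returns because Presto's hash functions take and return varbinary. A hex MD5 (the MD5 shape, as in MySQL) is a 32-character lowercase hex string. The sketch also assumes that Presto's TO_HEX emits uppercase digits. Under that assumption, the hex form written in Presto needs a LOWER(...) wrapper. SHA2 additionally carries a bit length (256 or 512).

```python
import hashlib

# Presto's hash functions take VARBINARY, so a string goes through TO_UTF8 first.
data = "abc".encode("utf-8")

# Binary-digest shape: raw 16-byte digest (what Presto's MD5 returns).
md5_digest = hashlib.md5(data).digest()

# Hex-string shape: 32-character lowercase hex string (what MySQL's MD5 returns).
md5_hex = hashlib.md5(data).hexdigest()

# Assumption: Presto's TO_HEX emits uppercase hex digits, like this ...
to_hex_upper = md5_digest.hex().upper()
# ... so LOWER(TO_HEX(MD5(TO_UTF8(x)))) is needed to match the lowercase hex form.
assert to_hex_upper.lower() == md5_hex

# SHA2 carries a bit length: SHA256 -> 256 bits, SHA512 -> 512 bits.
for bits, hash_fn in ((256, hashlib.sha256), (512, hashlib.sha512)):
    assert len(hash_fn(data).digest()) * 8 == bits

print(len(md5_digest), md5_hex)
```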

@@ -281,6 +281,9 @@ class Parser(parser.Parser):
             "TO_UTF8": lambda args: exp.Encode(
                 this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
             ),
+            "MD5": exp.MD5Digest.from_arg_list,
Owner


This seems to repeat the ClickHouse code. Can we move it into dialects so that ClickHouse and Presto both use the same implementation?

In dialects, you can create a dict

{ "MD5": ..., }

and then import that dict here and in ClickHouse, and use it in both places.
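The suggestion relies on Python dict unpacking. A shared dict is spliced into each dialect's FUNCTIONS mapping, and in a dict display later keys win. Below is a minimal sketch in which string labels stand in for the real builder callables. All names here (BASE_FUNCTIONS, SHARED_HASH_PARSERS and the per-dialect dicts) are hypothetical.

```python
# Hypothetical stand-ins: values are labels instead of real builder callables.
BASE_FUNCTIONS = {"MD5": "exp.MD5", "UPPER": "exp.Upper"}

# The shared dict suggested in the review (name is illustrative).
SHARED_HASH_PARSERS = {
    "MD5": "exp.MD5Digest",
    "SHA256": "exp.SHA2(length=256)",
    "SHA512": "exp.SHA2(length=512)",
}

# Each dialect splices the shared dict in; later keys override earlier ones,
# so the shared entries replace the base "MD5" entry.
CLICKHOUSE_FUNCTIONS = {**BASE_FUNCTIONS, **SHARED_HASH_PARSERS, "UNIQ": "exp.ApproxDistinct"}
PRESTO_FUNCTIONS = {**BASE_FUNCTIONS, **SHARED_HASH_PARSERS, "ARBITRARY": "exp.AnyValue"}

# A dialect that disagrees can still override a shared entry by listing it afterwards.
MYSQL_LIKE_FUNCTIONS = {**BASE_FUNCTIONS, **SHARED_HASH_PARSERS, "MD5": "exp.MD5"}

print(PRESTO_FUNCTIONS["MD5"], MYSQL_LIKE_FUNCTIONS["MD5"])
```

The position of the splice therefore matters. If it sits right after the base parser's functions, the shared entries override the base ones, and any dialect-specific key listed afterwards still takes precedence.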

Contributor Author


I don't think that is a good idea.

Do you mean I should implement it like this? This approach only works for the Parser. The Generator varies by dialect (e.g. when generating exp.MD5).

diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index cc1d6208..c42fadbb 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -19,6 +19,11 @@ DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsD
 DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
 JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
 
+HASH_FUNCTION_PARSER: t.Dict[str, t.Callable[[t.List], exp.Func]] = {
+    "MD5": exp.MD5Digest.from_arg_list,
+    "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
+    "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
+}
 
 if t.TYPE_CHECKING:
     from sqlglot._typing import B, E, F
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 0d2cfc1a..54c8a2ed 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -25,6 +25,7 @@ from sqlglot.dialects.dialect import (
     timestrtotime_sql,
     ts_or_ds_add_cast,
     unit_to_var,
+    HASH_FUNCTION_PARSER,
 )
 from sqlglot.helper import seq_get, split_num_words
 from sqlglot.tokens import TokenType
@@ -313,6 +314,7 @@ class BigQuery(Dialect):
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
+            **HASH_FUNCTION_PARSER,
             "DATE": _build_date,
             "DATE_ADD": build_date_delta_with_interval(exp.DateAdd),
             "DATE_SUB": build_date_delta_with_interval(exp.DateSub),
@@ -330,7 +332,6 @@ class BigQuery(Dialect):
             "JSON_EXTRACT_SCALAR": lambda args: exp.JSONExtractScalar(
                 this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string("$")
             ),
-            "MD5": exp.MD5Digest.from_arg_list,
             "TO_HEX": _build_to_hex,
             "PARSE_DATE": lambda args: build_formatted_time(exp.StrToDate, "bigquery")(
                 [seq_get(args, 1), seq_get(args, 0)]
@@ -344,8 +345,6 @@ class BigQuery(Dialect):
                 occurrence=seq_get(args, 3),
                 group=exp.Literal.number(1) if re.compile(args[1].name).groups == 1 else None,
             ),
-            "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
-            "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
             "SPLIT": lambda args: exp.Split(
                 # https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split
                 this=seq_get(args, 0),
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 362ccde3..f4393ece 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import (
     build_json_extract_path,
     rename_func,
     var_map_sql,
+    HASH_FUNCTION_PARSER,
 )
 from sqlglot.helper import is_int, seq_get
 from sqlglot.tokens import Token, TokenType
@@ -120,6 +121,7 @@ class ClickHouse(Dialect):
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
+            **HASH_FUNCTION_PARSER,
             "ANY": exp.AnyValue.from_arg_list,
             "ARRAYSUM": exp.ArraySum.from_arg_list,
             "COUNTIF": _build_count_if,
@@ -146,9 +148,6 @@ class ClickHouse(Dialect):
             "TUPLE": exp.Struct.from_arg_list,
             "UNIQ": exp.ApproxDistinct.from_arg_list,
             "XOR": lambda args: exp.Xor(expressions=args),
-            "MD5": exp.MD5Digest.from_arg_list,
-            "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
-            "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
         }
 
         AGG_FUNCTIONS = {
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index b6e44cb9..f151dff5 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -27,6 +27,7 @@ from sqlglot.dialects.dialect import (
     timestrtotime_sql,
     ts_or_ds_add_cast,
     unit_to_str,
+    HASH_FUNCTION_PARSER,
 )
 from sqlglot.dialects.hive import Hive
 from sqlglot.dialects.mysql import MySQL
@@ -233,6 +234,7 @@ class Presto(Dialect):
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
+            **HASH_FUNCTION_PARSER,
             "ARBITRARY": exp.AnyValue.from_arg_list,
             "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
             "APPROX_PERCENTILE": _build_approx_percentile,
@@ -281,9 +283,6 @@ class Presto(Dialect):
             "TO_UTF8": lambda args: exp.Encode(
                 this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
             ),
-            "MD5": exp.MD5Digest.from_arg_list,
-            "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
-            "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
         }
 
         FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
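The point that only the parser side can be shared is easier to see in a toy model. This is plain Python, not sqlglot's API, and Node, SHARED_PARSERS, GENERATORS, parse_call and generate are all hypothetical names. One shared table decides which node MD5(x) becomes, but each dialect still needs its own rendering of that node. The Presto and MySQL spellings are assumptions based on those engines' documented behavior: Presto's MD5 and TO_HEX work on varbinary, and MySQL's MD5 returns a hex string. They are not sqlglot's actual generator output.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Node:
    """Toy expression node; kind mimics class names such as MD5Digest."""
    kind: str
    arg: str


# Shared parser table: every dialect that uses it reads MD5(x) the same way,
# as a binary-digest node.
SHARED_PARSERS: Dict[str, Callable[[List[str]], Node]] = {
    "MD5": lambda args: Node("MD5Digest", args[0]),
}

# Generators differ per dialect, so they cannot live in one shared dict.
GENERATORS: Dict[str, Dict[str, Callable[[Node], str]]] = {
    "presto": {
        "MD5Digest": lambda n: f"MD5({n.arg})",
        # hex-string MD5: hash the UTF-8 bytes, hex-encode, then lowercase
        "MD5": lambda n: f"LOWER(TO_HEX(MD5(TO_UTF8({n.arg}))))",
    },
    "mysql": {
        # MySQL's MD5 already returns a hex string; UNHEX turns it into 16 bytes
        "MD5Digest": lambda n: f"UNHEX(MD5({n.arg}))",
        "MD5": lambda n: f"MD5({n.arg})",
    },
}


def parse_call(name: str, args: List[str]) -> Node:
    """Build a node from a function call using the shared parser table."""
    return SHARED_PARSERS[name](args)


def generate(dialect: str, node: Node) -> str:
    """Render a node with the chosen dialect's own generator table."""
    return GENERATORS[dialect][node.kind](node)


node = parse_call("MD5", ["x"])  # Presto-style MD5(x) -> binary digest node
print(generate("presto", node))  # MD5(x)
print(generate("mysql", node))   # UNHEX(MD5(x))
print(generate("presto", Node("MD5", "x")))  # LOWER(TO_HEX(MD5(TO_UTF8(x))))
```

Moving a generator entry into a shared dict would force one spelling on every dialect. The shared dict in the diff above therefore covers only parsing.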

viplazylmht added a commit to viplazylmht/sqlglot that referenced this pull request May 11, 2024
georgesittas (Collaborator) left a comment


LGTM, I'd personally look into refactoring if we get ≥3 dialects repeating the same logic, but don't feel strongly about it.

tobymao merged commit 58d5f2b into tobymao:main on May 11, 2024 (6 checks passed)
Linked issue closed by this pull request: Incorrect parsing and generating of SHA, SHA2, MD5, and MD5_Digest expressions in Presto/Trino dialect