From 9acba6bf4975f428289a67a6f5f3f6c1354d6e30 Mon Sep 17 00:00:00 2001 From: "shaojin.wensj" Date: Sat, 10 Sep 2022 20:33:27 +0800 Subject: [PATCH] improved SQLUtils.acceptXXX --- .../java/com/alibaba/druid/sql/SQLUtils.java | 80 ++++++++++++++++ .../druid/bvt/sql/mysql/SQLUtilsTest.java | 93 +++++++++++++++++++ 2 files changed, 173 insertions(+) diff --git a/src/main/java/com/alibaba/druid/sql/SQLUtils.java b/src/main/java/com/alibaba/druid/sql/SQLUtils.java index c4ceaea885..1eb995af8a 100644 --- a/src/main/java/com/alibaba/druid/sql/SQLUtils.java +++ b/src/main/java/com/alibaba/druid/sql/SQLUtils.java @@ -38,6 +38,8 @@ import com.alibaba.druid.sql.dialect.mysql.ast.MySqlObject; import com.alibaba.druid.sql.dialect.mysql.ast.clause.MySqlSelectIntoStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement; +import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock; +import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter; import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlOutputVisitor; import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor; import com.alibaba.druid.sql.dialect.odps.ast.OdpsCreateTableStatement; @@ -45,6 +47,7 @@ import com.alibaba.druid.sql.dialect.odps.visitor.OdpsASTVisitorAdapter; import com.alibaba.druid.sql.dialect.odps.visitor.OdpsOutputVisitor; import com.alibaba.druid.sql.dialect.odps.visitor.OdpsSchemaStatVisitor; +import com.alibaba.druid.sql.dialect.oracle.visitor.OracleASTVisitorAdapter; import com.alibaba.druid.sql.dialect.oracle.visitor.OracleOutputVisitor; import com.alibaba.druid.sql.dialect.oracle.visitor.OracleSchemaStatVisitor; import com.alibaba.druid.sql.dialect.oracle.visitor.OracleToMySqlOutputVisitor; @@ -930,6 +933,25 @@ public boolean visit(OdpsSelectQueryBlock x) { } }; break; + case mysql: + visitor = new MySqlASTVisitorAdapter() { + @Override + public boolean visit(SQLSelectQueryBlock x) { + if (filter == null || filter.test(x)) { + consumer.accept(x); + } + return true; + } + + @Override + public boolean visit(MySqlSelectQueryBlock x) { + if (filter == null || filter.test(x)) { + consumer.accept(x); + } + return true; + } + }; + break; default: visitor = new SQLASTVisitorAdapter() { @Override @@ -975,6 +997,28 @@ public boolean visit(SQLAggregateExpr x) { } }; break; + case mysql: + visitor = new MySqlASTVisitorAdapter() { + @Override + public boolean visit(SQLAggregateExpr x) { + if (filter == null || filter.test(x)) { + consumer.accept(x); + } + return true; + } + }; + break; + case oracle: + visitor = new OracleASTVisitorAdapter() { + @Override + public boolean visit(SQLAggregateExpr x) { + if (filter == null || filter.test(x)) { + consumer.accept(x); + } + return true; + } + }; + break; default: visitor = new SQLASTVisitorAdapter() { @Override @@ -1027,6 +1071,42 @@ public boolean visit(SQLAggregateExpr x) { } }; break; + case mysql: + visitor = new MySqlASTVisitorAdapter() { + @Override + public boolean visit(SQLMethodInvokeExpr x) { + if (filter == null || filter.test(x)) { + consumer.accept(x); + } + return true; + } + @Override + public boolean visit(SQLAggregateExpr x) { + if (filter == null || filter.test(x)) { + consumer.accept(x); + } + return true; + } + }; + break; + case oracle: + visitor = new OracleASTVisitorAdapter() { + @Override + public boolean visit(SQLMethodInvokeExpr x) { + if (filter == null || filter.test(x)) { + consumer.accept(x); + } + return true; + } + @Override + public boolean visit(SQLAggregateExpr x) { + if (filter == null || filter.test(x)) { + consumer.accept(x); + } + return true; + } + }; + break; default: visitor = new SQLASTVisitorAdapter() { @Override diff --git a/src/test/java/com/alibaba/druid/bvt/sql/mysql/SQLUtilsTest.java b/src/test/java/com/alibaba/druid/bvt/sql/mysql/SQLUtilsTest.java index 4c930da00e..13f34e394e 100644 --- a/src/test/java/com/alibaba/druid/bvt/sql/mysql/SQLUtilsTest.java +++ b/src/test/java/com/alibaba/druid/bvt/sql/mysql/SQLUtilsTest.java @@ -1,13 +1,18 @@ package com.alibaba.druid.bvt.sql.mysql; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; +import com.alibaba.druid.DbType; +import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr; import junit.framework.TestCase; import org.junit.Assert; import com.alibaba.druid.sql.SQLUtils; import com.alibaba.druid.util.JdbcConstants; +import org.junit.Test; public class SQLUtilsTest extends TestCase { public void test_format() throws Exception { @@ -57,4 +62,92 @@ public void test_format_3() throws Exception { "\t, NULL# and lottery_notice_issue>=2014062 order by lottery_notice_issue desc"; Assert.assertEquals(expected, formattedSql); } + + public void testAcceptFunctionTest() { + List functions = new ArrayList<>(); + SQLUtils.acceptFunction( + "select count(*) from t", + DbType.odps, + e -> functions.add(e), + e -> true + ); + assertEquals(1, functions.size()); + } + + public void testAcceptFunctionTest_1() { + List functions = new ArrayList<>(); + SQLUtils.acceptAggregateFunction( + "select count(*) from t", + DbType.odps, + e -> functions.add(e), + e -> true + ); + assertEquals(1, functions.size()); + } + + public void testAcceptFunctionTest_pg() { + List functions = new ArrayList<>(); + SQLUtils.acceptFunction( + "select count(*) from t", + DbType.postgresql, + e -> functions.add(e), + e -> true + ); + assertEquals(1, functions.size()); + } + + public void testAcceptFunctionTest_pg_1() { + List functions = new ArrayList<>(); + SQLUtils.acceptAggregateFunction( + "select count(*) from t", + DbType.postgresql, + e -> functions.add(e), + e -> true + ); + assertEquals(1, functions.size()); + } + + public void testAcceptFunctionTest_oracle() { + List functions = new ArrayList<>(); + SQLUtils.acceptFunction( + "select count(*) from t", + DbType.oracle, + e -> functions.add(e), + e -> true + ); + assertEquals(1, functions.size()); + } + + public void testAcceptFunctionTest_oracle_1() { + List functions = new ArrayList<>(); + SQLUtils.acceptAggregateFunction( + "select count(*) from t", + DbType.oracle, + e -> functions.add(e), + e -> true + ); + assertEquals(1, functions.size()); + } + + public void testAcceptFunctionTest_ck() { + List functions = new ArrayList<>(); + SQLUtils.acceptFunction( + "select count(*) from t", + DbType.clickhouse, + e -> functions.add(e), + e -> true + ); + assertEquals(1, functions.size()); + } + + public void testAcceptFunctionTest_ck_1() { + List functions = new ArrayList<>(); + SQLUtils.acceptAggregateFunction( + "select count(*) from t", + DbType.clickhouse, + e -> functions.add(e), + e -> true + ); + assertEquals(1, functions.size()); + } }