Skip to content

Commit

Permalink
增加 mariadb的 SET STATEMENT解析支持 #5861
Browse files Browse the repository at this point in the history
增加 mariadb的 SET STATEMENT解析支持 #5861
  • Loading branch information
lizongbo committed May 5, 2024
1 parent 05974bd commit 0467284
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.alibaba.druid.sql.ast.SQLCommentHint;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLName;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.SQLStatementImpl;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator;
Expand All @@ -37,6 +38,8 @@ public class SQLSetStatement extends SQLStatementImpl {

private boolean useSet;

SQLStatement maridbSetForStatement;

public SQLSetStatement() {
}

Expand Down Expand Up @@ -98,11 +101,22 @@ public void set(SQLExpr target, SQLExpr value) {
this.items.add(assignItem);
}

public SQLStatement getMaridbSetForStatement() {
return maridbSetForStatement;
}

public void setMaridbSetForStatement(SQLStatement maridbSetForStatement) {
this.maridbSetForStatement = maridbSetForStatement;
}

@Override
protected void accept0(SQLASTVisitor visitor) {
if (visitor.visit(this)) {
acceptChild(visitor, this.items);
acceptChild(visitor, this.hints);
if (maridbSetForStatement != null) {
maridbSetForStatement.accept(visitor);
}
}
visitor.endVisit(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.alibaba.druid.sql.repository.SchemaObject;
import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
import com.alibaba.druid.util.FnvHash;
import com.alibaba.druid.util.JdbcUtils;
import com.alibaba.druid.util.StringUtils;

import java.util.ArrayList;
Expand Down Expand Up @@ -5191,9 +5192,19 @@ public SQLStatement parseSet() {
return stmt;
} else {
SQLSetStatement stmt = new SQLSetStatement(getDbType());

boolean mariadbSetStatementFlag = false;
if (JdbcUtils.isMysqlDbType(getDbType())) {
if (lexer.identifierEquals("STATEMENT")) {
mariadbSetStatementFlag = true;
lexer.nextToken();
}
}
parseAssignItems(stmt.getItems(), stmt, true);

if (mariadbSetStatementFlag) {
accept(Token.FOR);
SQLStatement maridbSetForStatement = this.parseStatement();
stmt.setMaridbSetForStatement(maridbSetForStatement);
}
if (global != null) {
SQLVariantRefExpr varRef = (SQLVariantRefExpr) stmt.getItems().get(0).getTarget();
varRef.setGlobal(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4165,6 +4165,9 @@ public boolean visit(SQLSetStatement x) {
if (printSet) {
print0(ucase ? "SET " : "set ");
}
if (x.getMaridbSetForStatement() != null) {
print0(ucase ? "STATEMENT " : "statement ");
}
SQLSetStatement.Option option = x.getOption();
if (option != null) {
print(option.name());
Expand All @@ -4177,6 +4180,10 @@ public boolean visit(SQLSetStatement x) {

printAndAccept(x.getItems(), ", ");

if (x.getMaridbSetForStatement() != null) {
print0(ucase ? " FOR " : " for ");
x.getMaridbSetForStatement().accept(this);
}
if (x.getHints() != null && x.getHints().size() > 0) {
print(' ');
printAndAccept(x.getHints(), " ");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package com.alibaba.druid.bvt.sql.mysql.issues;

import java.util.List;

import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.parser.SQLParserUtils;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.alibaba.druid.sql.visitor.SchemaStatVisitor;

import org.junit.Test;

import static org.junit.Assert.assertEquals;

/**
* @author lizongbo
* @see <a href="https://github.com/alibaba/druid/issues/5861>Issue来源</a>
* @see <a href="https://mariadb.com/kb/en/set-statement/">SET STATEMENT</a>
*/
public class Issue5861 {

@Test
public void test_parse_set_statement() {
for (DbType dbType : new DbType[]{
DbType.mariadb,
DbType.mysql,

}) {

for (String sql : new String[]{
"SET STATEMENT max_statement_time=25 FOR select T.* from (\n"
+ "SELECT\n"
+ "head_pm_code\n"
+ "FROM\n"
+ "ef_ap_fee_detail\n"
+ "where audit_status = 0\n"
+ "and create_time >= '2023-12-02 00:00:00'\n"
+ "and create_time < '2023-12-03 00:00:00'\n"
+ "and (source_from='50' or status='30')\n"
+ "group by head_pm_code\n"
+ "limit 10000\n"
+ ") T;",
"SET STATEMENT join_cache_level=6, optimizer_switch='mrr=on' "
+ "FOR select * from t1 join t2 on t1.a=t2.a;",

}) {
System.out.println(dbType + "原始的sql===" + sql);
SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, dbType);
List<SQLStatement> statementList = parser.parseStatementList();
String sqlGen = statementList.toString();
System.out.println(dbType + "首次解析生成的sql===" + sqlGen);
StringBuilder sb = new StringBuilder();
for (SQLStatement statement : statementList) {
sb.append(statement.toString()).append(";");
}
sb.deleteCharAt(sb.length() - 1);
parser = SQLParserUtils.createSQLStatementParser(sb.toString(), dbType);
List<SQLStatement> statementListNew = parser.parseStatementList();
String sqlGenNew = statementList.toString();
System.out.println(dbType + "再次解析生成的sql===" + sqlGenNew);
assertEquals(statementList.toString(), statementListNew.toString());
}
}
}
}

0 comments on commit 0467284

Please sign in to comment.