diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/StatementParser.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/StatementParser.java index 335eb12534..3bfc67908f 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/StatementParser.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/StatementParser.java @@ -24,6 +24,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions; import java.util.Collections; import java.util.Objects; @@ -315,6 +316,10 @@ public boolean isQuery(String sql) { */ @InternalApi public boolean isUpdateStatement(String sql) { + // Skip any query hints at the beginning of the query. + if (sql.startsWith("@")) { + sql = removeStatementHint(sql); + } return statementStartsWith(sql, dmlStatements); } @@ -453,12 +458,16 @@ static String removeStatementHint(String sql) { // searching for the first occurrence of a keyword that should be preceded by a closing curly // brace at the end of the statement hint. int startStatementHintIndex = sql.indexOf('{'); - // Statement hints are only allowed for queries. + // Statement hints are allowed for both queries and DML statements. int startQueryIndex = -1; String upperCaseSql = sql.toUpperCase(); - for (String keyword : selectStatements) { + Set selectAndDmlStatements = + Sets.union(selectStatements, dmlStatements).immutableCopy(); + for (String keyword : selectAndDmlStatements) { startQueryIndex = upperCaseSql.indexOf(keyword); - if (startQueryIndex > -1) break; + if (startQueryIndex > -1) { + break; + } } if (startQueryIndex > -1) { int endStatementHintIndex = sql.substring(0, startQueryIndex).lastIndexOf('}'); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java index edde69a05b..6e0fe00720 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java @@ -317,7 +317,7 @@ public void testIsQuery() { } @Test - public void testQueryHints() { + public void testIsQuery_QueryHints() { // Valid query hints. assertTrue(parser.isQuery("@{JOIN_METHOD=HASH_JOIN} SELECT * FROM PersonsTable")); assertTrue(parser.isQuery("@ {JOIN_METHOD=HASH_JOIN} SELECT * FROM PersonsTable")); @@ -358,9 +358,51 @@ public void testQueryHints() { assertFalse(parser.isQuery("@{JOIN_METHOD=HASH_JOIN SELECT * FROM PersonsTable")); assertFalse(parser.isQuery("@JOIN_METHOD=HASH_JOIN} SELECT * FROM PersonsTable")); assertFalse(parser.isQuery("@JOIN_METHOD=HASH_JOIN SELECT * FROM PersonsTable")); + } + + @Test + public void testIsUpdate_QueryHints() { + // Valid query hints. + assertTrue( + parser.isUpdateStatement( + "@{LOCK_SCANNED_RANGES=exclusive} UPDATE FOO SET NAME='foo' WHERE ID=1")); + assertTrue( + parser.isUpdateStatement( + "@ {LOCK_SCANNED_RANGES=exclusive} UPDATE FOO SET NAME='foo' WHERE ID=1")); + assertTrue( + parser.isUpdateStatement( + "@{ LOCK_SCANNED_RANGES=exclusive} UPDATE FOO SET NAME='foo' WHERE ID=1")); + assertTrue( + parser.isUpdateStatement( + "@{LOCK_SCANNED_RANGES=exclusive } UPDATE FOO SET NAME='foo' WHERE ID=1")); + assertTrue( + parser.isUpdateStatement( + "@{LOCK_SCANNED_RANGES=exclusive}\nUPDATE FOO SET NAME='foo' WHERE ID=1")); + assertTrue( + parser.isUpdateStatement( + "@{\nLOCK_SCANNED_RANGES = exclusive \t}\n\t UPDATE FOO SET NAME='foo' WHERE ID=1")); + assertTrue( + parser.isUpdateStatement( + "@{LOCK_SCANNED_RANGES=exclusive}\n -- Single line comment\nUPDATE FOO SET NAME='foo' WHERE ID=1")); + assertTrue( + parser.isUpdateStatement( + "@{LOCK_SCANNED_RANGES=exclusive}\n /* Multi line comment\n with more comments\n */UPDATE FOO SET NAME='foo' WHERE ID=1")); + + // Multiple query hints. + assertTrue( + StatementParser.INSTANCE.isUpdateStatement( + "@{LOCK_SCANNED_RANGES=exclusive} @{USE_ADDITIONAL_PARALLELISM=TRUE} UPDATE FOO SET NAME='foo' WHERE ID=1")); + + // Invalid query hints. assertFalse( - StatementParser.INSTANCE.isQuery( - "@{FORCE_INDEX=index_name} @{JOIN_METHOD=HASH_JOIN} UPDATE tbl set FOO=1 WHERE ID=2")); + parser.isUpdateStatement( + "@{LOCK_SCANNED_RANGES=exclusive UPDATE FOO SET NAME='foo' WHERE ID=1")); + assertFalse( + parser.isUpdateStatement( + "@LOCK_SCANNED_RANGES=exclusive} UPDATE FOO SET NAME='foo' WHERE ID=1")); + assertFalse( + parser.isUpdateStatement( + "@LOCK_SCANNED_RANGES=exclusive UPDATE FOO SET NAME='foo' WHERE ID=1")); } @Test