Skip to content

Commit

Permalink
[FLINK-27741][table-planner] Fix NPE when use dense_rank() and rank()
Browse files Browse the repository at this point in the history
Co-authored-by: Sergey Nuyanzin <snuyanzin@gmail.com>

This closes #19797

Co-authored-by: chenzihao <chenzihao5@xiaomi.com>
  • Loading branch information
snuyanzin and chenzihao committed May 15, 2024
1 parent e16da86 commit 190522c
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ protected Expression orderKeyEqualsExpression() {
equalTo(lasValue, operand(i)));
}
Optional<Expression> ret = Arrays.stream(orderKeyEquals).reduce(ExpressionBuilder::and);
return ret.orElseGet(() -> literal(true));
return ret.orElseGet(() -> literal(false));
}

protected Expression generateInitLiteral(LogicalType orderType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,18 +529,23 @@ class AggFunctionFactory(
}

private def createRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_))
new RankAggFunction(argTypes)
new RankAggFunction(getArgTypesOrEmpty())
}

private def createDenseRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_))
new DenseRankAggFunction(argTypes)
new DenseRankAggFunction(getArgTypesOrEmpty())
}

private def createPercentRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_))
new PercentRankAggFunction(argTypes)
new PercentRankAggFunction(getArgTypesOrEmpty())
}

private def getArgTypesOrEmpty(): Array[LogicalType] = {
if (orderKeyIndexes != null) {
orderKeyIndexes.map(inputRowType.getChildren.get(_))
} else {
Array[LogicalType]()
}
}

private def createNTILEAggFUnction(argTypes: Array[LogicalType]): UserDefinedFunction = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,50 @@ OverAggregate(partitionBy=[c], window#0=[COUNT(*) AS w0$o0 RANG BETWEEN UNBOUNDE
]]>
</Resource>
</TestCase>
<TestCase name="testDenseRankOnOrder">
<Resource name="sql">
<![CDATA[SELECT a, DENSE_RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTableWithProctime]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(a=[$0], EXPR$1=[DENSE_RANK() OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST)])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[a, w0$o0 AS $1])
+- OverAggregate(partitionBy=[a], orderBy=[proctime ASC], window#0=[DENSE_RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, proctime, w0$o0])
+- Exchange(distribution=[forward])
+- Sort(orderBy=[a ASC, proctime ASC])
+- Exchange(distribution=[hash[a]])
+- Calc(select=[a, proctime])
+- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]], fields=[a, b, c, proctime])
]]>
</Resource>
</TestCase>
<TestCase name="testRankOnOver">
<Resource name="sql">
<![CDATA[SELECT a, RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTableWithProctime]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(a=[$0], EXPR$1=[RANK() OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST)])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[a, w0$o0 AS $1])
+- OverAggregate(partitionBy=[a], orderBy=[proctime ASC], window#0=[RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, proctime, w0$o0])
+- Exchange(distribution=[forward])
+- Sort(orderBy=[a ASC, proctime ASC])
+- Exchange(distribution=[hash[a]])
+- Calc(select=[a, proctime])
+- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]], fields=[a, b, c, proctime])
]]>
</Resource>
</TestCase>
<TestCase name="testOverWindowWithoutPartitionBy">
<Resource name="sql">
<![CDATA[SELECT c, SUM(a) OVER (ORDER BY b) FROM MyTable]]>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class OverAggregateTest extends TableTestBase {

private val util = batchTestUtil()
util.addTableSource[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
util.addTableSource[(Int, Long, String, Long)]("MyTableWithProctime", 'a, 'b, 'c, 'proctime)

@Test
def testOverWindowWithoutPartitionByOrderBy(): Unit = {
Expand All @@ -47,6 +48,18 @@ class OverAggregateTest extends TableTestBase {
util.verifyExecPlan("SELECT c, SUM(a) OVER (ORDER BY b) FROM MyTable")
}

@Test
def testDenseRankOnOrder(): Unit = {
util.verifyExecPlan(
"SELECT a, DENSE_RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTableWithProctime")
}

@Test
def testRankOnOver(): Unit = {
util.verifyExecPlan(
"SELECT a, RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTableWithProctime")
}

@Test
def testDiffPartitionKeysWithSameOrderKeys(): Unit = {
val sqlQuery =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,66 @@ class OverAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTest
assertThat(sink.getAppendResults.sorted).isEqualTo(expected.sorted)
}

@TestTemplate
def testDenseRankOnOver(): Unit = {
val t = failingDataSource(TestData.tupleData5)
.toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime)
tEnv.createTemporaryView("MyTable", t)
val sqlQuery = "SELECT a, DENSE_RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTable"

val sink = new TestingAppendSink
tEnv.sqlQuery(sqlQuery).toDataStream.addSink(sink)
env.execute()

val expected = List(
"1,1",
"2,1",
"2,2",
"3,1",
"3,2",
"3,3",
"4,1",
"4,2",
"4,3",
"4,4",
"5,1",
"5,2",
"5,3",
"5,4",
"5,5")
assertThat(expected.sorted).isEqualTo(sink.getAppendResults.sorted)
}

@TestTemplate
def testRankOnOver(): Unit = {
val t = failingDataSource(TestData.tupleData5)
.toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime)
tEnv.createTemporaryView("MyTable", t)
val sqlQuery = "SELECT a, RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTable"

val sink = new TestingAppendSink
tEnv.sqlQuery(sqlQuery).toDataStream.addSink(sink)
env.execute()

val expected = List(
"1,1",
"2,1",
"2,2",
"3,1",
"3,2",
"3,3",
"4,1",
"4,2",
"4,3",
"4,4",
"5,1",
"5,2",
"5,3",
"5,4",
"5,5")
assertThat(expected.sorted).isEqualTo(sink.getAppendResults.sorted)
}

@TestTemplate
def testProcTimeBoundedPartitionedRowsOver(): Unit = {
val t = failingDataSource(TestData.tupleData5)
Expand Down

0 comments on commit 190522c

Please sign in to comment.