Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48027][SQL] InjectRuntimeFilter for multi-level join should check child join type #46263

Closed
wants to merge 9 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF}
Expand Down Expand Up @@ -86,6 +87,17 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
private def extractSelectiveFilterOverScan(
plan: LogicalPlan,
filterCreationSideKey: Expression): Option[(Expression, LogicalPlan)] = {

def canExtractRight(joinType: JoinType): Boolean = joinType match {
case Inner | LeftSemi | RightOuter => true
case _ => false
}

def canExtractLeft(joinType: JoinType): Boolean = joinType match {
case Inner | LeftSemi | LeftOuter | LeftAnti => true
case _ => false
}

def extract(
p: LogicalPlan,
predicateReference: AttributeSet,
Expand Down Expand Up @@ -120,34 +132,49 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition),
currentPlan,
targetKey)
case ExtractEquiJoinKeys(_, lkeys, rkeys, _, _, left, right, _) =>
case ExtractEquiJoinKeys(joinType, lkeys, rkeys, _, _, left, right, _) =>
// Runtime filters use one side of the [[Join]] to build a set of join key values and prune
// the other side of the [[Join]]. It's also OK to use a superset of the join key values
// (ignore null values) to do the pruning.
Copy link
Contributor

@cloud-fan cloud-fan Apr 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm glad we described the idea in the comments. It's clear that for certain join types, the join child output is not a superset of the join output for transitive join keys.

// We assume other rules have already pushed predicates through join if possible.
// So the predicate references won't pass on anymore.
if (left.output.exists(_.semanticEquals(targetKey))) {
extract(left, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false,
Copy link
Contributor

@cloud-fan cloud-fan Apr 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's clarify the expected strategy a bit more: For the exact join key match, like the left table here, it's always OK to generate the runtime filter using this left table, no matter what the join type is. This is because left table always produce a superset of output of the join output regarding the left keys.

For transitive join key match, it's different. The right table here does not always generate a superset output regarding left keys. Let's look at an example

left table: 1, 2, 3
right table, 3, 4
left outer join output: (1, null), (2, null), (3, 3)
left keys: 1, 2, 3

So we can't use right table to generate runtime filter.

currentPlan = left, targetKey = targetKey).orElse {
// We can also extract from the right side if the join keys are transitive.
AngersZhuuuu marked this conversation as resolved.
Show resolved Hide resolved
val extractLeft = if (canExtractLeft(joinType)) {
extract(left, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false,
currentPlan = left, targetKey = targetKey)
} else {
None
}
val extractRight = if (canExtractRight(joinType)) {
lkeys.zip(rkeys).find(_._1.semanticEquals(targetKey)).map(_._2)
.flatMap { newTargetKey =>
extract(right, AttributeSet.empty,
hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = right,
targetKey = newTargetKey)
}
} else {
None
}
extractLeft.orElse(extractRight)
AngersZhuuuu marked this conversation as resolved.
Show resolved Hide resolved
} else if (right.output.exists(_.semanticEquals(targetKey))) {
extract(right, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false,
currentPlan = right, targetKey = targetKey).orElse {
val extractRight = if (canExtractRight(joinType)) {
extract(right, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false,
currentPlan = right, targetKey = targetKey)
} else {
None
}
val extractLeft = if (canExtractLeft(joinType)) {
// We can also extract from the left side if the join keys are transitive.
rkeys.zip(lkeys).find(_._1.semanticEquals(targetKey)).map(_._2)
.flatMap { newTargetKey =>
extract(left, AttributeSet.empty,
hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = left,
targetKey = newTargetKey)
}
} else {
None
}
extractRight.orElse(extractLeft)
AngersZhuuuu marked this conversation as resolved.
Show resolved Hide resolved
} else {
None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,8 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
"(bf1.c1 = bf2.c2 and bf2.a2 = 5)) as a join bf3 on bf3.c3 = a.c1", 2)
// left anti join unsupported.
// bf2 as creation side and inject runtime filter for bf3(by passing key).
assertRewroteWithBloomFilter("select * from (select * from bf1 left anti join bf2 on " +
"(bf1.c1 = bf2.c2 and bf2.a2 = 5)) as a join bf3 on bf3.c3 = a.c1")
assertDidNotRewriteWithBloomFilter("select * from (select * from bf1 left anti join bf2 " +
"on (bf1.c1 = bf2.c2 and bf2.a2 = 5)) as a join bf3 on bf3.c3 = a.c1")
// left anti join unsupported and hasn't selective filter.
assertRewroteWithBloomFilter("select * from (select * from bf1 left anti join bf2 on " +
"(bf1.c1 = bf2.c2 and bf1.a1 = 5)) as a join bf3 on bf3.c3 = a.c1", 0)
Expand Down