[SPARK-48027][SQL] InjectRuntimeFilter for multi-level join should check child join type #46263

Closed
wants to merge 9 commits into from
@@ -22,6 +22,7 @@ import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF}
@@ -86,6 +87,17 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
private def extractSelectiveFilterOverScan(
plan: LogicalPlan,
filterCreationSideKey: Expression): Option[(Expression, LogicalPlan)] = {

def canExtractLeft(joinType: JoinType): Boolean = joinType match {
case Inner | LeftSemi | LeftOuter | LeftAnti => true
case _ => false
}

def canExtractRight(joinType: JoinType): Boolean = joinType match {
case Inner | RightOuter => true
case _ => false
}

def extract(
p: LogicalPlan,
predicateReference: AttributeSet,
@@ -120,33 +132,61 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition),
currentPlan,
targetKey)
case ExtractEquiJoinKeys(_, lkeys, rkeys, _, _, left, right, _) =>
case ExtractEquiJoinKeys(joinType, lkeys, rkeys, _, _, left, right, _) =>
// Runtime filters use one side of the [[Join]] to build a set of join key values and prune
// the other side of the [[Join]]. It's also OK to use a superset of the join key values
// (ignore null values) to do the pruning.
@cloud-fan (Contributor), Apr 29, 2024:

I'm glad we described the idea in the comments. It's clear that for certain join types, the join child output is not a superset of the join output for transitive join keys.

// We assume other rules have already pushed predicates through join if possible.
// So the predicate references won't pass on anymore.
if (left.output.exists(_.semanticEquals(targetKey))) {
extract(left, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false,
@cloud-fan (Contributor), Apr 29, 2024:

Let's clarify the expected strategy a bit more. For an exact join key match, like the left table here, it's always OK to generate the runtime filter from the left table, no matter what the join type is, because the left table always produces a superset of the join output with respect to the left keys.

A transitive join key match is different. The right table here does not always produce a superset of the join output with respect to the left keys. Let's look at an example:

left table: 1, 2, 3
right table: 3, 4
left outer join output: (1, null), (2, null), (3, 3)
left keys: 1, 2, 3

So we can't use the right table to generate the runtime filter.
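
The example above can be checked with a minimal sketch on plain Scala collections. This is only an illustration: `LeftOuterJoinKeys`, `leftOuterJoin`, and the other names are hypothetical, nothing here is Spark API, and an exact key set stands in for the Bloom filter.

```scala
// Illustrative sketch: single-column integer "tables" as plain Scala Seqs.
object LeftOuterJoinKeys {
  // Left outer equi-join: every left row survives; unmatched rows pair with None.
  def leftOuterJoin(left: Seq[Int], right: Seq[Int]): Seq[(Int, Option[Int])] =
    left.flatMap { l =>
      val matches = right.filter(_ == l)
      if (matches.isEmpty) Seq((l, Option.empty[Int]))
      else matches.map(r => (l, Option(r)))
    }

  val left: Seq[Int] = Seq(1, 2, 3)
  val right: Seq[Int] = Seq(3, 4)

  // (1, None), (2, None), (3, Some(3))
  val joined: Seq[(Int, Option[Int])] = leftOuterJoin(left, right)

  // Left keys that appear in the join output: 1, 2, 3
  val outputLeftKeys: Set[Int] = joined.map(_._1).toSet

  // A filter built from the right table would only let key 3 through.
  val rightKeys: Set[Int] = right.toSet
  val passedByRightFilter: Set[Int] = outputLeftKeys.filter(rightKeys.contains)

  // Keys 1 and 2 are in the join output but would be pruned.
  val wronglyPruned: Set[Int] = outputLeftKeys.diff(passedByRightFilter)

  def main(args: Array[String]): Unit = {
    println(s"join output:           $joined")
    println(s"output left keys:      $outputLeftKeys")
    println(s"passed by right filter: $passedByRightFilter")
    println(s"wrongly pruned:        $wronglyPruned")
  }
}
```

Because the right keys are not a superset of the left keys in the join output, a filter built from the right table would drop rows that the left outer join must keep.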

currentPlan = left, targetKey = targetKey).orElse {
// We can also extract from the right side if the join keys are transitive.
lkeys.zip(rkeys).find(_._1.semanticEquals(targetKey)).map(_._2)
.flatMap { newTargetKey =>
extract(right, AttributeSet.empty,
hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = right,
targetKey = newTargetKey)
}
if (canExtractLeft(joinType)) {
Contributor:

As I explained at https://github.com/apache/spark/pull/46263/files#r1582495367 , we don't need this check here.

Contributor Author:

Extracting via the equi-join condition makes sense. The ELSE branch also needs to drop canExtractLeft, right?

extract(left, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false,
currentPlan = left, targetKey = targetKey)
} else {
None
}.orElse {
// For the exact join key match, like the left table here, it's always OK to generate
// the runtime filter using this left table, no matter what the join type is.
// This is because left table always produce a superset of output of the join output
// regarding the left keys.
// For transitive join key match, it's different. The right table here does
// not always generate a superset output regarding left keys.
// Let's look at an example
// left table: 1, 2, 3
// right table, 3, 4
// left outer join output: (1, null), (2, null), (3, 3)
// left keys: 1, 2, 3
// So we can't use right table to generate runtime filter.
if (canExtractRight(joinType)) {
lkeys.zip(rkeys).find(_._1.semanticEquals(targetKey)).map(_._2)
.flatMap { newTargetKey =>
extract(right, AttributeSet.empty,
hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = right,
targetKey = newTargetKey)
}
} else {
None
}
}
} else if (right.output.exists(_.semanticEquals(targetKey))) {
extract(right, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false,
currentPlan = right, targetKey = targetKey).orElse {
// We can also extract from the left side if the join keys are transitive.
rkeys.zip(lkeys).find(_._1.semanticEquals(targetKey)).map(_._2)
.flatMap { newTargetKey =>
extract(left, AttributeSet.empty,
hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = left,
targetKey = newTargetKey)
}
if (canExtractRight(joinType)) {
Contributor:

I think the code is wrong here. We only need this extra check for transitive join keys. In this branch, it's the left table.

Contributor Author:

Looks like we can just use canPruneLeft and canPruneRight; they follow the same rule.

Contributor Author:

Like the current version?

extract(right, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false,
currentPlan = right, targetKey = targetKey)
} else {
None
}.orElse {
if (canExtractLeft(joinType)) {
// We can also extract from the left side if the join keys are transitive.
rkeys.zip(lkeys).find(_._1.semanticEquals(targetKey)).map(_._2)
.flatMap { newTargetKey =>
extract(left, AttributeSet.empty,
hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = left,
targetKey = newTargetKey)
}
}
else {
None
}
}
} else {
None
@@ -356,8 +356,8 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
"(bf1.c1 = bf2.c2 and bf2.a2 = 5)) as a join bf3 on bf3.c3 = a.c1", 2)
// left anti join unsupported.
// bf2 as creation side and inject runtime filter for bf3(by passing key).
assertRewroteWithBloomFilter("select * from (select * from bf1 left anti join bf2 on " +
"(bf1.c1 = bf2.c2 and bf2.a2 = 5)) as a join bf3 on bf3.c3 = a.c1")
assertDidNotRewriteWithBloomFilter("select * from (select * from bf1 left anti join bf2 " +
"on (bf1.c1 = bf2.c2 and bf2.a2 = 5)) as a join bf3 on bf3.c3 = a.c1")
// left anti join unsupported and hasn't selective filter.
assertRewroteWithBloomFilter("select * from (select * from bf1 left anti join bf2 on " +
"(bf1.c1 = bf2.c2 and bf1.a1 = 5)) as a join bf3 on bf3.c3 = a.c1", 0)
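
The switch to `assertDidNotRewriteWithBloomFilter` for the left anti join case can be motivated with a small sketch on plain collections. The rows below are made-up sample data, `LeftAntiJoinKeys` and `Bf2Row` are hypothetical names, and an exact key set stands in for the Bloom filter; this is an assumption-laden illustration, not the suite's actual tables.

```scala
// Illustrative sketch: why bf2 cannot serve as the filter creation side
// when it sits on the right of a left anti join.
object LeftAntiJoinKeys {
  final case class Bf2Row(c2: Int, a2: Int)

  // Left anti join on (c1 = c2 and a2 = 5): keep left keys with NO matching right row.
  def leftAntiJoin(leftKeys: Seq[Int], right: Seq[Bf2Row]): Seq[Int] =
    leftKeys.filterNot(k => right.exists(r => r.c2 == k && r.a2 == 5))

  // Hypothetical sample data.
  val bf1Keys: Seq[Int] = Seq(1, 2, 3)
  val bf2: Seq[Bf2Row] = Seq(Bf2Row(c2 = 3, a2 = 5), Bf2Row(c2 = 2, a2 = 7))

  // Only key 3 has a match with a2 = 5, so the anti join keeps 1 and 2.
  val joinOutputKeys: Seq[Int] = leftAntiJoin(bf1Keys, bf2)

  // Keys a filter built from bf2 (after a2 = 5) would contain: just 3.
  val filterKeys: Set[Int] = bf2.filter(_.a2 == 5).map(_.c2).toSet

  // Join output keys that would survive such a filter: none.
  val keptByFilter: Seq[Int] = joinOutputKeys.filter(filterKeys.contains)

  def main(args: Array[String]): Unit = {
    println(s"anti join output keys: $joinOutputKeys")
    println(s"filter keys from bf2:  $filterKeys")
    println(s"kept by filter:        $keptByFilter")
  }
}
```

In this sketch the filter keys are exactly the keys the anti join excludes, so pruning bf3 with them would discard every row that should match the join output.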