Optimized push down filter #10291 (#10366)
dmitrybugakov committed May 3, 2024
1 parent 2c56a3c commit 8190cb9
Showing 1 changed file with 81 additions and 58 deletions.
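The diff below migrates PushDownFilter from the optimizer's borrow-and-clone try_optimize entry point (returning Result<Option<LogicalPlan>>) to the owned-plan rewrite entry point (returning Result<Transformed<LogicalPlan>>), which avoids cloning the plan and makes the "changed / unchanged" outcome explicit. As orientation before reading the diff, here is a minimal sketch of the shape a rule takes after this migration. It is illustrative only: NoopRule is a hypothetical rule, and the import paths are assumed from the re-exports this file itself uses.

// Hypothetical rule illustrating the OptimizerRule shape this commit migrates to.
use datafusion_common::tree_node::Transformed;
use datafusion_common::{internal_err, Result};
use datafusion_expr::LogicalPlan;
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::{OptimizerConfig, OptimizerRule};

struct NoopRule;

impl OptimizerRule for NoopRule {
    // The old entry point is stubbed out, exactly as PushDownFilter does below.
    fn try_optimize(
        &self,
        _plan: &LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Option<LogicalPlan>> {
        internal_err!("Should have called NoopRule::rewrite")
    }

    fn name(&self) -> &str {
        "noop_rule"
    }

    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    // Opting in to the owned-plan API.
    fn supports_rewrite(&self) -> bool {
        true
    }

    // The plan is taken by value; Transformed::yes / Transformed::no records
    // whether anything changed, and `.data` carries the resulting plan.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        Ok(Transformed::no(plan))
    }
}

Callers then obtain the rewritten plan via rule.rewrite(plan, &config)?.data, as the updated tests at the bottom of this diff do.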
139 changes: 81 additions & 58 deletions datafusion/optimizer/src/push_down_filter.rs
@@ -17,8 +17,7 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
use itertools::Itertools;

use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
@@ -29,6 +28,7 @@ use datafusion_common::{
};
use datafusion_expr::expr::Alias;
use datafusion_expr::expr_rewriter::replace_col;
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::{
CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union,
};
@@ -38,7 +38,8 @@ use datafusion_expr::{
ScalarFunctionDefinition, TableProviderFilterPushDown,
};

use itertools::Itertools;
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};

/// Optimizer rule for pushing (moving) filter expressions down in a plan so
/// they are applied as early as possible.
@@ -407,7 +408,7 @@ fn push_down_all_join(
right: &LogicalPlan,
on_filter: Vec<Expr>,
is_inner_join: bool,
) -> Result<LogicalPlan> {
) -> Result<Transformed<LogicalPlan>> {
let on_filter_empty = on_filter.is_empty();
// Get pushable predicates from current optimizer state
let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?;
@@ -505,41 +506,43 @@ fn push_down_all_join(
// wrap the join in a filter whose predicates must be kept
match conjunction(keep_predicates) {
Some(predicate) => {
Filter::try_new(predicate, Arc::new(plan)).map(LogicalPlan::Filter)
let new_filter_plan = Filter::try_new(predicate, Arc::new(plan))?;
Ok(Transformed::yes(LogicalPlan::Filter(new_filter_plan)))
}
None => Ok(plan),
None => Ok(Transformed::no(plan)),
}
}

fn push_down_join(
plan: &LogicalPlan,
join: &Join,
parent_predicate: Option<&Expr>,
) -> Result<Option<LogicalPlan>> {
let predicates = match parent_predicate {
Some(parent_predicate) => split_conjunction_owned(parent_predicate.clone()),
None => vec![],
};
) -> Result<Transformed<LogicalPlan>> {
// Split the parent predicate into individual conjunctive parts.
let predicates = parent_predicate
.map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));

// Convert JOIN ON predicate to Predicates
// Extract conjunctions from the JOIN's ON filter, if present.
let on_filters = join
.filter
.as_ref()
.map(|e| split_conjunction_owned(e.clone()))
.unwrap_or_default();
.map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));

let mut is_inner_join = false;
let infer_predicates = if join.join_type == JoinType::Inner {
is_inner_join = true;

// Only keep join keys where both sides are columns.
let join_col_keys = join
.on
.iter()
.flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) {
(Ok(l_col), Ok(r_col)) => Some((l_col, r_col)),
_ => None,
.filter_map(|(l, r)| {
let left_col = l.try_into_col().ok()?;
let right_col = r.try_into_col().ok()?;
Some((left_col, right_col))
})
.collect::<Vec<_>>();

// TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down
// For inner joins, duplicate filters for joined columns so filters can be pushed down
// to both sides. Take the following query as an example:
@@ -559,6 +562,7 @@ fn push_down_join(
.chain(on_filters.iter())
.filter_map(|predicate| {
let mut join_cols_to_replace = HashMap::new();

let columns = match predicate.to_columns() {
Ok(columns) => columns,
Err(e) => return Some(Err(e)),
@@ -596,20 +600,32 @@ fn push_down_join(
};

if on_filters.is_empty() && predicates.is_empty() && infer_predicates.is_empty() {
return Ok(None);
return Ok(Transformed::no(plan.clone()));
}
Ok(Some(push_down_all_join(

match push_down_all_join(
predicates,
infer_predicates,
plan,
&join.left,
&join.right,
on_filters,
is_inner_join,
)?))
) {
Ok(plan) => Ok(Transformed::yes(plan.data)),
Err(e) => Err(e),
}
}

impl OptimizerRule for PushDownFilter {
fn try_optimize(
&self,
_plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
internal_err!("Should have called PushDownFilter::rewrite")
}

fn name(&self) -> &str {
"push_down_filter"
}
@@ -618,21 +634,24 @@ impl OptimizerRule for PushDownFilter {
Some(ApplyOrder::TopDown)
}

fn try_optimize(
fn supports_rewrite(&self) -> bool {
true
}

fn rewrite(
&self,
plan: &LogicalPlan,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
) -> Result<Transformed<LogicalPlan>> {
let filter = match plan {
LogicalPlan::Filter(filter) => filter,
// we also need to pushdown filter in Join.
LogicalPlan::Join(join) => return push_down_join(plan, join, None),
_ => return Ok(None),
LogicalPlan::Filter(ref filter) => filter,
LogicalPlan::Join(ref join) => return push_down_join(&plan, join, None),
_ => return Ok(Transformed::no(plan)),
};

let child_plan = filter.input.as_ref();
let new_plan = match child_plan {
LogicalPlan::Filter(child_filter) => {
LogicalPlan::Filter(ref child_filter) => {
let parents_predicates = split_conjunction(&filter.predicate);
let set: HashSet<&&Expr> = parents_predicates.iter().collect();

@@ -652,20 +671,18 @@ impl OptimizerRule for PushDownFilter {
new_predicate,
child_filter.input.clone(),
)?);
self.try_optimize(&new_filter, _config)?
.unwrap_or(new_filter)
self.rewrite(new_filter, _config)?.data
}
LogicalPlan::Repartition(_)
| LogicalPlan::Distinct(_)
| LogicalPlan::Sort(_) => {
// commutable
let new_filter = plan.with_new_exprs(
plan.expressions(),
vec![child_plan.inputs()[0].clone()],
)?;
child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])?
}
LogicalPlan::SubqueryAlias(subquery_alias) => {
LogicalPlan::SubqueryAlias(ref subquery_alias) => {
let mut replace_map = HashMap::new();
for (i, (qualifier, field)) in
subquery_alias.input.schema().iter().enumerate()
@@ -685,7 +702,7 @@
)?);
child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])?
}
LogicalPlan::Projection(projection) => {
LogicalPlan::Projection(ref projection) => {
// A projection is filter-commutable if it does not contain volatile predicates, or if its volatile
// predicates are not used in the filter. However, we should rewrite all predicate expressions.
// collect projection.
@@ -742,10 +759,10 @@ impl OptimizerRule for PushDownFilter {
}
}
}
None => return Ok(None),
None => return Ok(Transformed::no(plan)),
}
}
LogicalPlan::Union(union) => {
LogicalPlan::Union(ref union) => {
let mut inputs = Vec::with_capacity(union.inputs.len());
for input in &union.inputs {
let mut replace_map = HashMap::new();
@@ -770,7 +787,7 @@ impl OptimizerRule for PushDownFilter {
schema: plan.schema().clone(),
})
}
LogicalPlan::Aggregate(agg) => {
LogicalPlan::Aggregate(ref agg) => {
// We can push down predicates that only reference group-by expressions.
let group_expr_columns = agg
.group_expr
@@ -821,13 +838,15 @@ impl OptimizerRule for PushDownFilter {
None => new_agg,
}
}
LogicalPlan::Join(join) => {
match push_down_join(&filter.input, join, Some(&filter.predicate))? {
Some(optimized_plan) => optimized_plan,
None => return Ok(None),
}
LogicalPlan::Join(ref join) => {
push_down_join(
&unwrap_arc(filter.clone().input),
join,
Some(&filter.predicate),
)?
.data
}
LogicalPlan::CrossJoin(cross_join) => {
LogicalPlan::CrossJoin(ref cross_join) => {
let predicates = split_conjunction_owned(filter.predicate.clone());
let join = convert_cross_join_to_inner_join(cross_join.clone())?;
let join_plan = LogicalPlan::Join(join);
Expand All @@ -843,9 +862,9 @@ impl OptimizerRule for PushDownFilter {
vec![],
true,
)?;
convert_to_cross_join_if_beneficial(plan)?
convert_to_cross_join_if_beneficial(plan.data)?
}
LogicalPlan::TableScan(scan) => {
LogicalPlan::TableScan(ref scan) => {
let filter_predicates = split_conjunction(&filter.predicate);
let results = scan
.source
@@ -892,7 +911,7 @@ impl OptimizerRule for PushDownFilter {
None => new_scan,
}
}
LogicalPlan::Extension(extension_plan) => {
LogicalPlan::Extension(ref extension_plan) => {
let prevent_cols =
extension_plan.node.prevent_predicate_push_down_columns();

Expand Down Expand Up @@ -935,9 +954,10 @@ impl OptimizerRule for PushDownFilter {
None => new_extension,
}
}
_ => return Ok(None),
_ => return Ok(Transformed::no(plan)),
};
Ok(Some(new_plan))

Ok(Transformed::yes(new_plan))
}
}

@@ -1024,16 +1044,12 @@ fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {

#[cfg(test)]
mod tests {
use super::*;
use std::any::Any;
use std::fmt::{Debug, Formatter};

use crate::optimizer::Optimizer;
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use crate::test::*;
use crate::OptimizerContext;

use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;

use datafusion_common::ScalarValue;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::logical_plan::table_scan;
@@ -1043,7 +1059,13 @@ mod tests {
Volatility,
};

use async_trait::async_trait;
use crate::optimizer::Optimizer;
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use crate::test::*;
use crate::OptimizerContext;

use super::*;

fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}

fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
@@ -2298,9 +2320,9 @@ mod tests {
table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;

let optimized_plan = PushDownFilter::new()
.try_optimize(&plan, &OptimizerContext::new())
.rewrite(plan, &OptimizerContext::new())
.expect("failed to optimize plan")
.unwrap();
.data;

let expected = "\
Filter: a = Int64(1)\
@@ -2667,8 +2689,9 @@ Projection: a, b
// Originally, global state helped avoid generating and pushing down duplicate Filters.
// Now that the global state is removed, double-check that duplicate Filters are still avoided.
let optimized_plan = PushDownFilter::new()
.try_optimize(&plan, &OptimizerContext::new())?
.expect("failed to optimize plan");
.rewrite(plan, &OptimizerContext::new())
.expect("failed to optimize plan")
.data;
assert_optimized_plan_eq(optimized_plan, expected)
}

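For completeness, here is a small end-to-end usage sketch modeled on the updated tests above. It is illustrative only: the table name, schema, and predicate are made up, and the import paths assume the crates' public re-exports (PushDownFilter::new(), OptimizerContext::new(), and the rewrite(...).data pattern used in this diff).

// Illustrative only: drives PushDownFilter once through the new rewrite API.
// The table name, schema, and predicate below are hypothetical.
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::{col, lit, logical_plan::table_scan};
use datafusion_optimizer::push_down_filter::PushDownFilter;
use datafusion_optimizer::{OptimizerContext, OptimizerRule};

fn main() -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int64, false),
    ]);

    // Filter sits above a Projection; the rule should push it toward the scan.
    let plan = table_scan(Some("test"), &schema, None)?
        .project(vec![col("a"), col("b")])?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    // New calling convention: the plan is passed by value and the result is a
    // Transformed<LogicalPlan>; `.data` extracts the rewritten plan.
    let optimized = PushDownFilter::new()
        .rewrite(plan, &OptimizerContext::new())?
        .data;

    println!("{}", optimized.display_indent());
    Ok(())
}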