Skip to content

Commit

Permalink
simplify returns closure
Browse files Browse the repository at this point in the history
  • Loading branch information
milenkovicm committed May 10, 2024
1 parent de51434 commit a678e6d
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 61 deletions.
39 changes: 19 additions & 20 deletions datafusion-examples/examples/simplify_udaf_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_common::tree_node::Transformed;
use datafusion_expr::function::AggregateFunctionSimplification;
use datafusion_expr::simplify::SimplifyInfo;

use std::{any::Any, sync::Arc};
Expand Down Expand Up @@ -88,27 +88,26 @@ impl AggregateUDFImpl for BetterAvgUdaf {
}
// we override method, to return new expression which would substitute
// user defined function call
fn simplify(
&self,
aggregate_function: AggregateFunction,
_info: &dyn SimplifyInfo,
) -> Result<Transformed<Expr>> {
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
// as an example for this functionality we replace UDF function
// with build-in aggregate function to illustrate the use
let expr = Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::BuiltIn(
// yes it is the same Avg, `BetterAvgUdaf` was just a
// marketing pitch :)
datafusion_expr::aggregate_function::AggregateFunction::Avg,
),
args: aggregate_function.args,
distinct: aggregate_function.distinct,
filter: aggregate_function.filter,
order_by: aggregate_function.order_by,
null_treatment: aggregate_function.null_treatment,
});

Ok(Transformed::yes(expr))
let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction,
_: &dyn SimplifyInfo| {
Ok(Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::BuiltIn(
// yes it is the same Avg, `BetterAvgUdaf` was just a
// marketing pitch :)
datafusion_expr::aggregate_function::AggregateFunction::Avg,
),
args: aggregate_function.args,
distinct: aggregate_function.distinct,
filter: aggregate_function.filter,
order_by: aggregate_function.order_by,
null_treatment: aggregate_function.null_treatment,
}))
};

Some(Box::new(simplify))
}
}

Expand Down
13 changes: 13 additions & 0 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,16 @@ pub type PartitionEvaluatorFactory =
/// its state, given its return datatype.
pub type StateTypeFunction =
Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;

/// [crate::udaf::AggregateUDFImpl::simplify] simplifier closure
/// A closure with two arguments:
/// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked
/// * 'info': [crate::simplify::SimplifyInfo]
///
/// closure returns simplified [Expr] or an error.
pub type AggregateFunctionSimplification = Box<
dyn Fn(
crate::expr::AggregateFunction,
&dyn crate::simplify::SimplifyInfo,
) -> Result<Expr>,
>;
41 changes: 15 additions & 26 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions

use crate::expr::AggregateFunction;
use crate::function::AccumulatorArgs;
use crate::function::{AccumulatorArgs, AggregateFunctionSimplification};
use crate::groups_accumulator::GroupsAccumulator;
use crate::simplify::SimplifyInfo;
use crate::utils::format_state_name;
use crate::{Accumulator, Expr};
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
use arrow::datatypes::{DataType, Field};
use datafusion_common::tree_node::Transformed;
use datafusion_common::{not_impl_err, Result};
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
Expand Down Expand Up @@ -201,12 +198,8 @@ impl AggregateUDF {
/// Do the function rewrite
///
/// See [`AggregateUDFImpl::simplify`] for more details.
pub fn simplify(
&self,
aggregate_function: AggregateFunction,
info: &dyn SimplifyInfo,
) -> Result<Transformed<Expr>> {
self.inner.simplify(aggregate_function, info)
pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
self.inner.simplify()
}
}

Expand Down Expand Up @@ -368,7 +361,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
&[]
}

/// Optionally apply per-UDF simplification / rewrite rules.
/// Optionally apply per-UDaF simplification / rewrite rules.
///
/// This can be used to apply function specific simplification rules during
/// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
Expand All @@ -379,22 +372,18 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
/// optimizations manually for specific UDFs.
///
/// # Arguments
/// * 'aggregate_function': Aggregate function to be simplified
/// * 'info': Simplification information
///
/// # Returns
/// [`Transformed`] indicating the result of the simplification NOTE
/// if the function cannot be simplified, [Expr::AggregateFunction] with unmodified [AggregateFunction]
/// should be returned
fn simplify(
&self,
aggregate_function: AggregateFunction,
_info: &dyn SimplifyInfo,
) -> Result<Transformed<Expr>> {
Ok(Transformed::yes(Expr::AggregateFunction(
aggregate_function,
)))
///
/// [None] if simplify is not defined or,
///
/// Or, a closure with two arguments:
/// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked
/// * 'info': [crate::simplify::SimplifyInfo]
///
/// closure returns simplified [Expr] or an error.
///
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
None
}
}

Expand Down
26 changes: 11 additions & 15 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1385,14 +1385,12 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
func_def: AggregateFunctionDefinition::UDF(ref udaf),
..
}) => {
let udaf = udaf.clone();
if let Expr::AggregateFunction(aggregate_function) = expr {
udaf.simplify(aggregate_function, info)?
} else {
unreachable!("this branch should be unreachable")
}) => match (udaf.simplify(), expr) {
(Some(simplify_function), Expr::AggregateFunction(af)) => {
Transformed::yes(simplify_function(af, info)?)
}
}
(_, expr) => Transformed::no(expr),
},

//
// Rules for Between
Expand Down Expand Up @@ -1760,7 +1758,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
#[cfg(test)]
mod tests {
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
use datafusion_expr::{interval_arithmetic::Interval, *};
use datafusion_expr::{
function::AggregateFunctionSimplification, interval_arithmetic::Interval, *,
};
use std::{
collections::HashMap,
ops::{BitAnd, BitOr, BitXor},
Expand Down Expand Up @@ -3791,15 +3791,11 @@ mod tests {
unimplemented!("not needed for testing")
}

fn simplify(
&self,
aggregate_function: datafusion_expr::expr::AggregateFunction,
_info: &dyn SimplifyInfo,
) -> Result<Transformed<Expr>> {
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
if self.simplify {
Ok(Transformed::yes(col("result_column")))
Some(Box::new(|_, _| Ok(col("result_column"))))
} else {
Ok(Transformed::no(Expr::AggregateFunction(aggregate_function)))
None
}
}
}
Expand Down

0 comments on commit a678e6d

Please sign in to comment.