Skip to content

Commit

Permalink
Add simplify method to aggregate function (#10354)
Browse files Browse the repository at this point in the history
* add simplify method for aggregate function

* simplify returns closure
  • Loading branch information
milenkovicm committed May 13, 2024
1 parent 58cc4e1 commit 230c68c
Show file tree
Hide file tree
Showing 4 changed files with 328 additions and 3 deletions.
180 changes: 180 additions & 0 deletions datafusion-examples/examples/simplify_udaf_expression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_expr::function::AggregateFunctionSimplification;
use datafusion_expr::simplify::SimplifyInfo;

use std::{any::Any, sync::Arc};

use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
use datafusion::error::Result;
use datafusion::{assert_batches_eq, prelude::*};
use datafusion_common::cast::as_float64_array;
use datafusion_expr::{
expr::{AggregateFunction, AggregateFunctionDefinition},
function::AccumulatorArgs,
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};

/// This example shows how to use the `AggregateUDFImpl::simplify` API to replace a
/// user-defined aggregate function call with a different expression, which is built
/// by the closure returned from the `simplify` method.
///
/// `BetterAvgUdaf` is the user-defined aggregate used for the demonstration: it is
/// exposed under the name `better_avg` and its `simplify` implementation rewrites
/// the call into the built-in `Avg` aggregate.
#[derive(Debug, Clone)]
struct BetterAvgUdaf {
// Argument signature handed out by `AggregateUDFImpl::signature`.
signature: Signature,
}

impl BetterAvgUdaf {
    /// Create a new instance of the `BetterAvgUdaf` struct.
    ///
    /// The signature accepts exactly one `Float64` argument and is declared
    /// with `Volatility::Immutable`.
    fn new() -> Self {
        let signature = Signature::exact(vec![DataType::Float64], Volatility::Immutable);
        Self { signature }
    }
}

impl AggregateUDFImpl for BetterAvgUdaf {
    /// Expose `self` as `Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name of the function: `better_avg`.
    fn name(&self) -> &str {
        "better_avg"
    }

    /// Signature created in `BetterAvgUdaf::new`.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always `Float64`, whatever the argument types are.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    // The accumulator-related methods below are deliberately left unimplemented
    // and panic when called; the panic messages state that they are not expected
    // to be invoked (presumably because `simplify` replaces the call first).

    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        unimplemented!("should not be invoked")
    }

    fn state_fields(
        &self,
        _name: &str,
        _value_type: DataType,
        _ordering_fields: Vec<arrow_schema::Field>,
    ) -> Result<Vec<arrow_schema::Field>> {
        unimplemented!("should not be invoked")
    }

    fn groups_accumulator_supported(&self) -> bool {
        true
    }

    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
        unimplemented!("should not get here")
    }

    /// Overridden to return a closure producing the expression that substitutes
    /// the user defined function call.
    ///
    /// To illustrate the API, the call is replaced with the built-in `Avg`
    /// aggregate; every other part of the original call (arguments, `distinct`,
    /// filter, ordering and null treatment) is carried over unchanged.
    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
        let replace_with_avg = |udaf_call: datafusion_expr::expr::AggregateFunction,
                                _: &dyn SimplifyInfo| {
            // take the call apart; only the function definition is swapped out
            let AggregateFunction {
                args,
                distinct,
                filter,
                order_by,
                null_treatment,
                ..
            } = udaf_call;

            // yes it is the same Avg, `BetterAvgUdaf` was just a
            // marketing pitch :)
            let func_def = AggregateFunctionDefinition::BuiltIn(
                datafusion_expr::aggregate_function::AggregateFunction::Avg,
            );

            Ok(Expr::AggregateFunction(AggregateFunction {
                func_def,
                args,
                distinct,
                filter,
                order_by,
                null_treatment,
            }))
        };

        Some(Box::new(replace_with_avg))
    }
}

/// Create a local session context with an in-memory table named `t`.
///
/// The table has two non-nullable `Float32` columns, `a` and `b`, and its data
/// is split over two partitions (one record batch each).
fn create_context() -> Result<SessionContext> {
    use datafusion::datasource::MemTable;

    // schema shared by both batches and by the table itself
    let columns = vec![
        Field::new("a", DataType::Float32, false),
        Field::new("b", DataType::Float32, false),
    ];
    let schema = Arc::new(Schema::new(columns));

    // first partition: three rows
    let first_partition = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
            Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
        ],
    )?;
    // second partition: a single row
    let second_partition = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Float32Array::from(vec![16.0])),
            Arc::new(Float32Array::from(vec![2.0])),
        ],
    )?;

    let ctx = SessionContext::new();

    // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
    let partitions = vec![vec![first_partition], vec![second_partition]];
    let table = MemTable::try_new(schema, partitions)?;
    ctx.register_table("t", Arc::new(table))?;

    Ok(ctx)
}

/// Runs the example: registers `better_avg` and evaluates it over table `t`,
/// once through SQL and once through the DataFrame API.
///
/// Column `a` holds 2, 4, 8 and 16 and column `b` is 2 in every row, so both
/// queries produce a single value, 7.5.
#[tokio::main]
async fn main() -> Result<()> {
    let ctx = create_context()?;

    // the clone is registered with the context; the original is kept for the
    // DataFrame API call further down
    let better_avg = AggregateUDF::from(BetterAvgUdaf::new());
    ctx.register_udaf(better_avg.clone());

    // SQL API
    let result = ctx
        .sql("SELECT better_avg(a) FROM t group by b")
        .await?
        .collect()
        .await?;

    // Every line of the expected table must have the same width as the border:
    // the value cell is padded with spaces up to the width of the column header.
    let expected = [
        "+-----------------+",
        "| better_avg(t.a) |",
        "+-----------------+",
        "| 7.5             |",
        "+-----------------+",
    ];

    assert_batches_eq!(expected, &result);

    // DataFrame API: aggregate without grouping expressions
    let df = ctx.table("t").await?;
    let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?;

    let results = df.collect().await?;
    // the assertion above shows a single group; read the first column of the
    // first batch as a Float64 array
    let result = as_float64_array(results[0].column(0))?;

    assert!((result.value(0) - 7.5).abs() < f64::EPSILON);
    println!("The average of [2,4,8,16] is {}", result.value(0));

    Ok(())
}
13 changes: 13 additions & 0 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,16 @@ pub type PartitionEvaluatorFactory =
/// its state, given its return datatype.
pub type StateTypeFunction =
Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;

/// Simplifier closure returned by [crate::udaf::AggregateUDFImpl::simplify].
///
/// The closure takes two arguments:
/// * `aggregate_function`: the [crate::expr::AggregateFunction] for which `simplify` has been invoked
/// * `info`: [crate::simplify::SimplifyInfo]
///
/// The closure returns the simplified [Expr] or an error.
pub type AggregateFunctionSimplification = Box<
dyn Fn(
crate::expr::AggregateFunction,
&dyn crate::simplify::SimplifyInfo,
) -> Result<Expr>,
>;
33 changes: 32 additions & 1 deletion datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions

use crate::function::AccumulatorArgs;
use crate::function::{AccumulatorArgs, AggregateFunctionSimplification};
use crate::groups_accumulator::GroupsAccumulator;
use crate::utils::format_state_name;
use crate::{Accumulator, Expr};
Expand Down Expand Up @@ -199,6 +199,12 @@ impl AggregateUDF {
pub fn coerce_types(&self, _args: &[DataType]) -> Result<Vec<DataType>> {
// Not supported yet: the argument types are ignored and the result is always
// whatever `not_impl_err!` produces, with a message naming this function.
not_impl_err!("coerce_types not implemented for {:?} yet", self.name())
}
/// Returns the optional rewrite (simplification) closure for this function.
///
/// This forwards to `simplify` on the wrapped `inner` implementation and
/// returns its result unchanged.
///
/// See [`AggregateUDFImpl::simplify`] for more details.
pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
self.inner.simplify()
}
}

impl<F> From<F> for AggregateUDF
Expand Down Expand Up @@ -358,6 +364,31 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn aliases(&self) -> &[String] {
// Default implementation: an empty slice, i.e. no aliases.
&[]
}

/// Optionally apply per-UDAF simplification / rewrite rules.
///
/// This can be used to apply function specific simplification rules during
/// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
/// implementation does nothing.
///
/// Note that DataFusion handles simplifying arguments and "constant
/// folding" (replacing a function call with constant arguments such as
/// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
/// optimizations manually for specific UDFs.
///
/// # Returns
///
/// [None] if simplify is not defined, or a closure with two arguments:
/// * `aggregate_function`: the [crate::expr::AggregateFunction] for which `simplify` has been invoked
/// * `info`: [crate::simplify::SimplifyInfo]
///
/// The closure returns the simplified [Expr] or an error.
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
None
}
}

/// AggregateUDF that adds an alias to the underlying function. It is better to
Expand Down
105 changes: 103 additions & 2 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
};
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{InList, InSubquery};
use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
Expand Down Expand Up @@ -1382,6 +1382,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
}
}

Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
func_def: AggregateFunctionDefinition::UDF(ref udaf),
..
}) => match (udaf.simplify(), expr) {
(Some(simplify_function), Expr::AggregateFunction(af)) => {
Transformed::yes(simplify_function(af, info)?)
}
(_, expr) => Transformed::no(expr),
},

//
// Rules for Between
//
Expand Down Expand Up @@ -1748,7 +1758,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
#[cfg(test)]
mod tests {
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
use datafusion_expr::{interval_arithmetic::Interval, *};
use datafusion_expr::{
function::AggregateFunctionSimplification, interval_arithmetic::Interval, *,
};
use std::{
collections::HashMap,
ops::{BitAnd, BitOr, BitXor},
Expand Down Expand Up @@ -3698,4 +3710,93 @@ mod tests {
assert_eq!(expr, expected);
assert_eq!(num_iter, 2);
}
/// Checks both outcomes of UDAF simplification: a UDAF that provides a
/// `simplify` closure is replaced by the expression the closure returns, and a
/// UDAF without one is left as it was.
#[test]
fn test_simplify_udaf() {
    // wraps a mock implementation into an argument-less aggregate function call
    let call_of = |mock: SimplifyMockUdaf| {
        let udaf = AggregateUDF::new_from_impl(mock);
        Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
            udaf.into(),
            vec![],
            false,
            None,
            None,
            None,
        ))
    };

    // with `simplify` defined the call turns into the mock's result column
    let with_simplify = call_of(SimplifyMockUdaf::new_with_simplify());
    let expected = col("result_column");
    assert_eq!(simplify(with_simplify), expected);

    // without `simplify` the expression comes back unchanged
    let without_simplify = call_of(SimplifyMockUdaf::new_without_simplify());
    let expected = without_simplify.clone();
    assert_eq!(simplify(without_simplify), expected);
}

/// Mock UDAF used by the tests that cover UDAF simplification.
///
/// The `simplify` flag records whether the mock should provide a
/// simplification or leave the expression untouched.
#[derive(Debug, Clone)]
struct SimplifyMockUdaf {
    simplify: bool,
}

impl SimplifyMockUdaf {
    /// Mock whose simplify method returns a new expression.
    fn new_with_simplify() -> Self {
        SimplifyMockUdaf { simplify: true }
    }

    /// Mock whose simplify method reports no change.
    fn new_without_simplify() -> Self {
        SimplifyMockUdaf { simplify: false }
    }
}

impl AggregateUDFImpl for SimplifyMockUdaf {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "mock_simplify"
    }

    // Everything from here down to `simplify` is a stub that panics when called.

    fn signature(&self) -> &Signature {
        unimplemented!()
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        unimplemented!("not needed for tests")
    }

    fn accumulator(
        &self,
        _acc_args: function::AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        unimplemented!("not needed for tests")
    }

    fn groups_accumulator_supported(&self) -> bool {
        unimplemented!("not needed for testing")
    }

    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
        unimplemented!("not needed for testing")
    }

    /// Returns `None` when the mock was built without simplification; otherwise
    /// returns a closure that ignores both of its arguments and always yields
    /// `col("result_column")`.
    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
        if !self.simplify {
            return None;
        }

        Some(Box::new(|_, _| Ok(col("result_column"))))
    }
}
}

0 comments on commit 230c68c

Please sign in to comment.