Skip to content

Commit

Permalink
add simplify method for aggregate function
Browse files Browse the repository at this point in the history
  • Loading branch information
milenkovicm committed May 7, 2024
1 parent 9fd697c commit 62d381e
Show file tree
Hide file tree
Showing 3 changed files with 366 additions and 1 deletion.
185 changes: 185 additions & 0 deletions datafusion-examples/examples/simplify_udaf_expression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_expr::simplify::SimplifyInfo;

use std::{any::Any, sync::Arc};

use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
use datafusion::error::Result;
use datafusion::{assert_batches_eq, prelude::*};
use datafusion_common::cast::as_float64_array;
use datafusion_expr::{
expr::{AggregateFunction, AggregateFunctionDefinition},
function::AccumulatorArgs,
simplify::ExprSimplifyResult,
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};

/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
/// defined aggregate function with a different expression which is defined in the `simplify` method.

#[derive(Debug, Clone)]
struct BetterAvgUdaf {
signature: Signature,
}

impl BetterAvgUdaf {
/// Create a new instance of the GeoMeanUdaf struct
fn new() -> Self {
Self {
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
}
}
}

impl AggregateUDFImpl for BetterAvgUdaf {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"better_avg"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
unimplemented!("should not be invoked")
}

fn state_fields(
&self,
_name: &str,
_value_type: DataType,
_ordering_fields: Vec<arrow_schema::Field>,
) -> Result<Vec<arrow_schema::Field>> {
unimplemented!("should not be invoked")
}

fn groups_accumulator_supported(&self) -> bool {
true
}

fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
unimplemented!("should not get here");
}
// we override method, to return new expression which would substitute
// user defined function call
fn simplify(
&self,
args: Vec<Expr>,
distinct: &bool,
filter: &Option<Box<Expr>>,
order_by: &Option<Vec<Expr>>,
null_treatment: &Option<datafusion_sql::sqlparser::ast::NullTreatment>,
_info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
// as an example for this functionality we replace UDF function
// with build-in aggregate function to illustrate the use
let expr = Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::BuiltIn(
// yes it is the same Avg, `BetterAvgUdaf` was just a
// marketing pitch :)
datafusion_expr::aggregate_function::AggregateFunction::Avg,
),
args,
distinct: *distinct,
filter: filter.clone(),
order_by: order_by.clone(),
null_treatment: *null_treatment,
});

Ok(ExprSimplifyResult::Simplified(expr))
}
}

// create local session context with an in-memory table
fn create_context() -> Result<SessionContext> {
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float32, false),
]));

// define data in two partitions
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
],
)?;
let batch2 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![16.0])),
Arc::new(Float32Array::from(vec![2.0])),
],
)?;

let ctx = SessionContext::new();

// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
ctx.register_table("t", Arc::new(provider))?;
Ok(ctx)
}

#[tokio::main]
async fn main() -> Result<()> {
let ctx = create_context()?;

let better_avg = AggregateUDF::from(BetterAvgUdaf::new());
ctx.register_udaf(better_avg.clone());

let result = ctx
.sql("SELECT better_avg(a) FROM t group by b")
.await?
.collect()
.await?;

let expected = [
"+-----------------+",
"| better_avg(t.a) |",
"+-----------------+",
"| 7.5 |",
"+-----------------+",
];

assert_batches_eq!(expected, &result);

let df = ctx.table("t").await?;
let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?;

let results = df.collect().await?;
let result = as_float64_array(results[0].column(0))?;

assert!((result.value(0) - 7.5).abs() < f64::EPSILON);
println!("The average of [2,4,8,16] is {}", result.value(0));

Ok(())
}
48 changes: 48 additions & 0 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

use crate::function::AccumulatorArgs;
use crate::groups_accumulator::GroupsAccumulator;
use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
use crate::utils::format_state_name;
use crate::{Accumulator, Expr};
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{not_impl_err, Result};
use sqlparser::ast::NullTreatment;
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
Expand Down Expand Up @@ -195,6 +197,21 @@ impl AggregateUDF {
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
self.inner.create_groups_accumulator()
}
/// Do the function rewrite
///
/// See [`AggregateUDFImpl::simplify`] for more details.
pub fn simplify(
&self,
args: Vec<Expr>,
distinct: &bool,
filter: &Option<Box<Expr>>,
order_by: &Option<Vec<Expr>>,
null_treatment: &Option<NullTreatment>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
self.inner
.simplify(args, distinct, filter, order_by, null_treatment, info)
}
}

impl<F> From<F> for AggregateUDF
Expand Down Expand Up @@ -354,6 +371,37 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn aliases(&self) -> &[String] {
&[]
}

/// Optionally apply per-UDF simplification / rewrite rules.
///
/// This can be used to apply function specific simplification rules during
/// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
/// implementation does nothing.
///
/// Note that DataFusion handles simplifying arguments and "constant
/// folding" (replacing a function call with constant arguments such as
/// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
/// optimizations manually for specific UDFs.
///
/// # Arguments
/// * 'args': The arguments of the function
/// * 'schema': The schema of the function
///
/// # Returns
/// [`ExprSimplifyResult`] indicating the result of the simplification NOTE
/// if the function cannot be simplified, the arguments *MUST* be returned
/// unmodified
fn simplify(
&self,
args: Vec<Expr>,
_distinct: &bool,
_filter: &Option<Box<Expr>>,
_order_by: &Option<Vec<Expr>>,
_null_treatment: &Option<NullTreatment>,
_info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
Ok(ExprSimplifyResult::Original(args))
}
}

/// AggregateUDF that adds an alias to the underlying function. It is better to
Expand Down

0 comments on commit 62d381e

Please sign in to comment.