Skip to content

Commit

Permalink
feat(function): Add mode Aggregate function. (#16627)
Browse files Browse the repository at this point in the history
* add mode()

* fix typo

* fix

---------

Co-authored-by: sundyli <[email protected]>
  • Loading branch information
Freejww and sundy-li authored Oct 17, 2024
1 parent efa1e96 commit 16e1e53
Show file tree
Hide file tree
Showing 7 changed files with 295 additions and 1 deletion.
167 changes: 167 additions & 0 deletions src/query/functions/src/aggregates/aggregate_mode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright 2021 Datafuse Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::hash::Hash;
use std::ops::AddAssign;
use std::sync::Arc;

use borsh::BorshDeserialize;
use borsh::BorshSerialize;
use databend_common_exception::Result;
use databend_common_expression::types::*;
use databend_common_expression::with_number_mapped_type;
use databend_common_expression::AggregateFunctionRef;
use databend_common_expression::Scalar;

use super::FunctionData;
use super::UnaryState;
use crate::aggregates::aggregate_function_factory::AggregateFunctionDescription;
use crate::aggregates::assert_unary_arguments;
use crate::aggregates::AggregateUnaryFunction;

/// Aggregation state for the `mode` aggregate function.
///
/// The state is a frequency table keyed by the owned scalar value; it derives
/// the Borsh traits so it can be serialized and deserialized.
#[derive(BorshSerialize, BorshDeserialize)]
pub struct ModeState<T>
where
T: ValueType,
T::Scalar: Ord + Hash + BorshSerialize + BorshDeserialize,
{
/// Number of occurrences recorded so far for each distinct value.
pub frequency_map: HashMap<T::Scalar, u64>,
}

impl<T> Default for ModeState<T>
where
    T: ValueType,
    T::Scalar: Ord + Hash + BorshSerialize + BorshDeserialize,
{
    /// Creates a state whose frequency table is empty.
    fn default() -> Self {
        let frequency_map = HashMap::new();
        Self { frequency_map }
    }
}

impl<T> UnaryState<T, T> for ModeState<T>
where
    T: ValueType + Sync + Send,
    T::Scalar: Ord + Hash + Sync + Send + BorshSerialize + BorshDeserialize,
{
    /// Records one more occurrence of `other` in the frequency table.
    ///
    /// `_function_data` is not used by this state.
    fn add(
        &mut self,
        other: T::ScalarRef<'_>,
        _function_data: Option<&dyn FunctionData>,
    ) -> Result<()> {
        match self.frequency_map.entry(T::to_owned_scalar(other)) {
            Entry::Occupied(mut slot) => *slot.get_mut() += 1,
            Entry::Vacant(slot) => {
                slot.insert(1);
            }
        }

        Ok(())
    }

    /// Folds the counts of `rhs` into this state, summing counts of values
    /// present in both tables.
    fn merge(&mut self, rhs: &Self) -> Result<()> {
        for (key, count) in rhs.frequency_map.iter() {
            // Look up by reference first so the key is only cloned when it is
            // actually missing from this table.
            match self.frequency_map.get_mut(key) {
                Some(total) => total.add_assign(count),
                None => {
                    self.frequency_map.insert(key.clone(), *count);
                }
            }
        }

        Ok(())
    }

    /// Appends the most frequent value to `builder`.
    ///
    /// When no value has been recorded, the builder's default item is pushed
    /// instead. When several values share the highest count, the smallest one
    /// according to `Ord` is chosen: `HashMap` iteration order is arbitrary, so
    /// picking purely by count would make the result vary between runs.
    fn merge_result(
        &mut self,
        builder: &mut T::ColumnBuilder,
        _function_data: Option<&dyn FunctionData>,
    ) -> Result<()> {
        let winner = self
            .frequency_map
            .iter()
            .max_by(|(lhs_key, lhs_count), (rhs_key, rhs_count)| {
                // Higher count wins; on equal counts the keys are compared in
                // reverse so that the smaller key ranks as the maximum.
                lhs_count
                    .cmp(rhs_count)
                    .then_with(|| rhs_key.cmp(lhs_key))
            });

        match winner {
            Some((key, _)) => T::push_item(builder, T::to_scalar_ref(key)),
            None => T::push_default(builder),
        }

        Ok(())
    }
}

/// Builds the `mode` aggregate function for the given argument list.
///
/// The first argument's data type selects the state implementation: number and
/// 128/256-bit decimal types get a `ModeState` specialized for their scalar
/// type, and every other data type falls back to `AnyType`. The argument type
/// is also passed as the function's return type.
///
/// Any error returned by `assert_unary_arguments` is propagated to the caller.
pub fn try_create_aggregate_mode_function(
    display_name: &str,
    params: Vec<Scalar>,
    arguments: Vec<DataType>,
) -> Result<AggregateFunctionRef> {
    assert_unary_arguments(display_name, arguments.len())?;

    let input_type = arguments[0].clone();
    // NOTE(review): `with_need_drop(true)` presumably requests that the state be
    // dropped because it owns a heap-allocated map - confirm against
    // `AggregateUnaryFunction`.
    with_number_mapped_type!(|NUM| match &input_type {
        DataType::Number(NumberDataType::NUM) => {
            let mode_fn = AggregateUnaryFunction::<
                ModeState<NumberType<NUM>>,
                NumberType<NUM>,
                NumberType<NUM>,
            >::try_create(
                display_name,
                input_type.clone(),
                params,
                input_type.clone(),
            )
            .with_need_drop(true);
            Ok(Arc::new(mode_fn))
        }
        DataType::Decimal(DecimalDataType::Decimal128(_)) => {
            let mode_fn = AggregateUnaryFunction::<
                ModeState<Decimal128Type>,
                Decimal128Type,
                Decimal128Type,
            >::try_create(
                display_name,
                input_type.clone(),
                params,
                input_type.clone(),
            )
            .with_need_drop(true);
            Ok(Arc::new(mode_fn))
        }
        DataType::Decimal(DecimalDataType::Decimal256(_)) => {
            let mode_fn = AggregateUnaryFunction::<
                ModeState<Decimal256Type>,
                Decimal256Type,
                Decimal256Type,
            >::try_create(
                display_name,
                input_type.clone(),
                params,
                input_type.clone(),
            )
            .with_need_drop(true);
            Ok(Arc::new(mode_fn))
        }
        _ => {
            // Generic fallback for all remaining data types.
            let mode_fn =
                AggregateUnaryFunction::<ModeState<AnyType>, AnyType, AnyType>::try_create(
                    display_name,
                    input_type.clone(),
                    params,
                    input_type.clone(),
                )
                .with_need_drop(true);
            Ok(Arc::new(mode_fn))
        }
    })
}

/// Returns the factory description for `mode`, wrapping
/// `try_create_aggregate_mode_function` as its creator.
pub fn aggregate_mode_function_desc() -> AggregateFunctionDescription {
    let creator = Box::new(try_create_aggregate_mode_function);
    AggregateFunctionDescription::creator(creator)
}
3 changes: 3 additions & 0 deletions src/query/functions/src/aggregates/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use super::aggregate_covariance::aggregate_covariance_sample_desc;
use super::aggregate_min_max_any::aggregate_any_function_desc;
use super::aggregate_min_max_any::aggregate_max_function_desc;
use super::aggregate_min_max_any::aggregate_min_function_desc;
use super::aggregate_mode::aggregate_mode_function_desc;
use super::aggregate_stddev::aggregate_stddev_pop_function_desc;
use super::aggregate_stddev::aggregate_stddev_samp_function_desc;
use super::aggregate_window_funnel::aggregate_window_funnel_function_desc;
Expand Down Expand Up @@ -141,6 +142,8 @@ impl Aggregators {
);

factory.register("histogram", aggregate_histogram_function_desc());

factory.register("mode", aggregate_mode_function_desc());
}

pub fn register_combinator(factory: &mut AggregateFunctionFactory) {
Expand Down
2 changes: 2 additions & 0 deletions src/query/functions/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ mod aggregate_json_array_agg;
mod aggregate_json_object_agg;
mod aggregate_kurtosis;
mod aggregate_min_max_any;
mod aggregate_mode;
mod aggregate_null_result;
mod aggregate_quantile_cont;
mod aggregate_quantile_disc;
Expand Down Expand Up @@ -64,6 +65,7 @@ pub use aggregate_json_array_agg::*;
pub use aggregate_json_object_agg::*;
pub use aggregate_kurtosis::*;
pub use aggregate_min_max_any::*;
pub use aggregate_mode::*;
pub use aggregate_null_result::AggregateNullResultFunction;
pub use aggregate_quantile_cont::*;
pub use aggregate_quantile_disc::*;
Expand Down
10 changes: 10 additions & 0 deletions src/query/functions/tests/it/aggregates/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ fn test_agg() {
test_agg_histogram(file, eval_aggr);
test_agg_json_array_agg(file, eval_aggr);
test_agg_json_object_agg(file, eval_aggr);
test_agg_mode(file, eval_aggr);
}

#[test]
Expand Down Expand Up @@ -111,6 +112,7 @@ fn test_agg_group_by() {
test_agg_group_array_moving_sum(file, eval_aggr);
test_agg_json_array_agg(file, eval_aggr);
test_agg_json_object_agg(file, eval_aggr);
test_agg_mode(file, simulate_two_groups_group_by);
}

fn gen_bitmap_data() -> Column {
Expand Down Expand Up @@ -139,6 +141,7 @@ fn get_example() -> Vec<(&'static str, Column)> {
("a", Int64Type::from_data(vec![4i64, 3, 2, 1])),
("b", UInt64Type::from_data(vec![1u64, 2, 3, 4])),
("c", UInt64Type::from_data(vec![1u64, 2, 1, 3])),
("d", UInt64Type::from_data(vec![1u64, 1, 1, 1])),
(
"x_null",
UInt64Type::from_data_with_validity(vec![1u64, 2, 3, 4], vec![
Expand Down Expand Up @@ -882,3 +885,10 @@ fn test_agg_json_object_agg(file: &mut impl Write, simulator: impl AggregationSi
simulator,
);
}

/// Runs the `mode` aggregate over a constant, a NULL literal, a column whose
/// values are all `1`, and the `all_null` column, writing each result to `file`.
fn test_agg_mode(file: &mut impl Write, simulator: impl AggregationSimulator) {
    for expr in ["mode(1)", "mode(NULL)", "mode(d)", "mode(all_null)"] {
        run_agg_ast(file, expr, get_example().as_slice(), simulator);
    }
}
40 changes: 40 additions & 0 deletions src/query/functions/tests/it/aggregates/testdata/agg.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1492,3 +1492,43 @@ evaluation (internal):
+--------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+


ast: mode(1)
evaluation (internal):
+--------+---------------------------------------------------------------+
| Column | Data |
+--------+---------------------------------------------------------------+
| a | Int64([4, 3, 2, 1]) |
| Output | NullableColumn { column: UInt8([1]), validity: [0b_______1] } |
+--------+---------------------------------------------------------------+


ast: mode(NULL)
evaluation (internal):
+--------+---------------------+
| Column | Data |
+--------+---------------------+
| a | Int64([4, 3, 2, 1]) |
| Output | Null { len: 1 } |
+--------+---------------------+


ast: mode(d)
evaluation (internal):
+--------+----------------------------------------------------------------+
| Column | Data |
+--------+----------------------------------------------------------------+
| d | UInt64([1, 1, 1, 1]) |
| Output | NullableColumn { column: UInt64([1]), validity: [0b_______1] } |
+--------+----------------------------------------------------------------+


ast: mode(all_null)
evaluation (internal):
+----------+-------------------------------------------------------------------------+
| Column | Data |
+----------+-------------------------------------------------------------------------+
| all_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0000] } |
| Output | NullableColumn { column: UInt64([0]), validity: [0b_______0] } |
+----------+-------------------------------------------------------------------------+


40 changes: 40 additions & 0 deletions src/query/functions/tests/it/aggregates/testdata/agg_group_by.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1430,3 +1430,43 @@ evaluation (internal):
+--------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+


ast: mode(1)
evaluation (internal):
+--------+------------------------------------------------------------------+
| Column | Data |
+--------+------------------------------------------------------------------+
| a | Int64([4, 3, 2, 1]) |
| Output | NullableColumn { column: UInt8([1, 1]), validity: [0b______11] } |
+--------+------------------------------------------------------------------+


ast: mode(NULL)
evaluation (internal):
+--------+---------------------+
| Column | Data |
+--------+---------------------+
| a | Int64([4, 3, 2, 1]) |
| Output | Null { len: 2 } |
+--------+---------------------+


ast: mode(d)
evaluation (internal):
+--------+-------------------------------------------------------------------+
| Column | Data |
+--------+-------------------------------------------------------------------+
| d | UInt64([1, 1, 1, 1]) |
| Output | NullableColumn { column: UInt64([1, 1]), validity: [0b______11] } |
+--------+-------------------------------------------------------------------+


ast: mode(all_null)
evaluation (internal):
+----------+-------------------------------------------------------------------------+
| Column | Data |
+----------+-------------------------------------------------------------------------+
| all_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0000] } |
| Output | NullableColumn { column: UInt64([0, 0]), validity: [0b______00] } |
+----------+-------------------------------------------------------------------------+


Original file line number Diff line number Diff line change
Expand Up @@ -418,5 +418,37 @@ statement ok
DROP TABLE d

statement ok
DROP DATABASE db1
create or replace table aggr(k int, v decimal(10,2));

query I
select mode(v) from aggr;
----
NULL

statement ok
insert into aggr (k, v) values
(1, 10),
(1, 10),
(1, 10),
(2, 20),
(2, 20),
(2, 21),
(3, null);

query I
select mode(v) from aggr;
----
10.00

query II
select k, mode(v) from aggr group by k order by k;
----
1 10.00
2 20.00
3 NULL

statement ok
DROP TABLE aggr

statement ok
DROP DATABASE db1

0 comments on commit 16e1e53

Please sign in to comment.