Skip to content

Commit

Permalink
fix: summarize raising error when a grouping col is all NA (or mostly NA) (#459)
Browse files Browse the repository at this point in the history

* fix(pandas): summarize works with na group cols, preserves keys
* tests: correctly skip unneeded duckdb test
  • Loading branch information
machow committed Nov 16, 2022
1 parent 99127df commit b1e1768
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 12 deletions.
19 changes: 16 additions & 3 deletions siuba/dply/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _mutate_cols(__data, args, kwargs):


def _make_groupby_safe(gdf):
return gdf.obj.groupby(gdf.grouper, group_keys=False)
return gdf.obj.groupby(gdf.grouper, group_keys=False, dropna=False)


MSG_TYPE_ERROR = "The first argument to {func} must be one of: {types}"
Expand Down Expand Up @@ -363,9 +363,9 @@ def group_by(__data, *args, add = False, **kwargs):
# ensures group levels are recalculated if varname was in transmute
groupings[varname] = varname

return tmp_df.groupby(list(groupings.values()))
return tmp_df.groupby(list(groupings.values()), dropna=False, group_keys=True)

return tmp_df.groupby(by = by_vars)
return tmp_df.groupby(by = by_vars, dropna=False, group_keys=True)


@singledispatch2((pd.DataFrame, DataFrameGroupBy))
Expand Down Expand Up @@ -563,6 +563,19 @@ def summarize(__data, *args, **kwargs):

@summarize.register(DataFrameGroupBy)
def _summarize(__data, *args, **kwargs):
if __data.dropna or not __data.group_keys:
warnings.warn(
f"Grouped data passed to summarize must have dropna=False and group_keys=True."
" Regrouping with these arguments set."
)

if __data.grouper.dropna:
# will need to recalculate groupings, otherwise it ignores dropna
group_cols = [ping.name for ping in __data.grouper.groupings]
else:
group_cols = __data.grouper.groupings
__data = __data.obj.groupby(group_cols, dropna=False, group_keys=True)

df_summarize = summarize.registry[pd.DataFrame]

df = __data.apply(df_summarize, *args, **kwargs)
Expand Down
11 changes: 2 additions & 9 deletions siuba/tests/test_sql_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,10 @@ def test_raw_sql_mutate_grouped(backend, df):
)


@pytest.mark.skip_backend("snowflake") # supported by snowflake
@pytest.mark.skip_backend("snowflake", "duckdb") # they support this behavior
@backend_sql
def test_raw_sql_mutate_refer_previous_raise_dberror(backend, skip_backend, df):
# Note: unlikely will be able to support this case. Normally we analyze
if backend.name == "duckdb":
# duckdb dialect re-raises the engines exception, which is RuntimeError
# the expression to know whether we need to create a subquery.
import duckdb
exc = duckdb.BinderException
else:
exc = sqlalchemy.exc.DatabaseError
exc = sqlalchemy.exc.DatabaseError

with pytest.raises(exc):
assert_equal_query(
Expand Down
11 changes: 11 additions & 0 deletions siuba/tests/test_verb_mutate.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,17 @@ def test_mutate_reassign_all_cols_keeps_rowsize(dfs):
data_frame(a = [1,1,1], b = [2,2,2])
)


def test_mutate_grouped_pandas_no_dropna():
    """Grouped mutate on a frame whose grouping column is entirely None.

    The expected frame still contains both input rows, with ``res`` computed
    as ``x + 1``.
    """
    ungrouped = data_frame(x = [1, 2], g = [None, None])

    # group on the all-missing column, then add the derived column
    pipeline = group_by(_.g) >> mutate(res = _.x + 1)
    expected = data_frame(x = [1, 2], g = [None, None], res = [2, 3])

    assert_equal_query(ungrouped, pipeline, expected)


@backend_sql
def test_mutate_window_funcs(backend):
data = data_frame(idx = range(0, 4), x = range(1, 5), g = [1,1,2,2])
Expand Down
48 changes: 48 additions & 0 deletions siuba/tests/test_verb_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-mutate.R
"""

import numpy as np

from siuba import _, mutate, select, group_by, summarize, filter, show_query, arrange
from siuba.dply.vector import row_number, n
Expand Down Expand Up @@ -47,6 +49,52 @@ def test_summarize_after_mutate_cuml_win(backend, df_float):
)


def test_summarize_keeps_na_grouping_cols(backend):
    """Summarize over an all-None grouping column yields one group row.

    The expected missing-value marker in the ``g`` column is ``np.nan`` when
    ``backend.name`` is "pandas" and ``None`` for any other backend.
    """
    local_df = data_frame(x = [1, 2, 3], g = [None, None, None])
    loaded = backend.load_df(local_df)

    # the expected NA representation depends on the backend under test
    na_value = np.nan if backend.name == "pandas" else None

    pipeline = group_by(_.g) >> summarize(res = _.x.min())
    expected = data_frame(g = [na_value], res = [1])

    assert_equal_query(loaded, pipeline, expected)


def test_summarize_regroups_group_keys():
    """Summarize on a groupby built with ``group_keys=False`` warns.

    A ``UserWarning`` matching "group_keys=True" is expected, and the result
    must still contain the single NaN group row.
    """
    raw = data_frame(x = [1, 2, 3], g = [None, None, None])

    # deliberately group with the group_keys setting the warning refers to
    badly_grouped = raw.groupby("g", dropna=False, group_keys=False)

    with pytest.warns(UserWarning, match="group_keys=True"):
        # everything is built inside the block so the same calls are covered
        # by the warning capture
        pipeline = summarize(res = _.x.min())
        expected = data_frame(g = [np.nan], res = [1])
        assert_equal_query(badly_grouped, pipeline, expected)


def test_summarize_regroups_dropna():
    """Summarize on a groupby built with ``dropna=True`` warns.

    A ``UserWarning`` matching "dropna=False" is expected, and the result
    must still contain the single NaN group row.
    """
    raw = data_frame(x = [1, 2, 3], g = [None, None, None])

    # deliberately group with the dropna setting the warning refers to
    badly_grouped = raw.groupby("g", dropna=True, group_keys=True)

    with pytest.warns(UserWarning, match="dropna=False"):
        # everything is built inside the block so the same calls are covered
        # by the warning capture
        pipeline = summarize(res = _.x.min())
        expected = data_frame(g = [np.nan], res = [1])
        assert_equal_query(badly_grouped, pipeline, expected)


@backend_sql
def test_summarize_keeps_group_vars(backend, gdf):
q = gdf >> summarize(n = n(_))
Expand Down

0 comments on commit b1e1768

Please sign in to comment.