Skip to content

Commit

Permalink
Merge pull request #374 from machow/fix-join-on-arg
Browse files Browse the repository at this point in the history
Fix semi_join inferring on arg
  • Loading branch information
machow committed Jan 13, 2022
2 parents b76ba35 + f03a8f4 commit f5b36f8
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 6 deletions.
11 changes: 9 additions & 2 deletions siuba/dply/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,9 +1110,16 @@ def semi_join(left, right = None, on = None):
on_cols, right_on = map(list, zip(*on.items()))
right = right[right_on].rename(dict(zip(right_on, on_cols)))
elif on is None:
on_cols = set(left.columns).intersection(set(right.columns))
warnings.warn(
"No on column passed to join. "
"Inferring join columns instead using shared column names."
)

on_cols = list(set(left.columns).intersection(set(right.columns)))
if not len(on_cols):
raise Exception("No joining column specified, and no shared column names")
raise Exception("No join column specified, and no shared column names")

warnings.warn("Detected shared columns: %s" % on_cols)
elif isinstance(on, str):
on_cols = [on]
else:
Expand Down
31 changes: 27 additions & 4 deletions siuba/sql/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import warnings

from siuba.dply.verbs import (
singledispatch2,
Expand Down Expand Up @@ -936,7 +937,7 @@ def _semi_join(left, right = None, on = None, *args, sql_on = None):
right_sel = right.last_op.alias()

# handle arguments ----
on = _validate_join_arg_on(on, sql_on)
on = _validate_join_arg_on(on, sql_on, left_sel, right_sel)

# create join conditions ----
bool_clause = _create_join_conds(left_sel, right_sel, on)
Expand All @@ -962,7 +963,7 @@ def _anti_join(left, right = None, on = None, *args, sql_on = None):
right_sel = right.last_op.alias()

# handle arguments ----
on = _validate_join_arg_on(on, sql_on)
on = _validate_join_arg_on(on, sql_on, left, right)

# create join conditions ----
bool_clause = _create_join_conds(left_sel, right_sel, on)
Expand All @@ -981,7 +982,7 @@ def _raise_if_args(args):
if len(args):
raise NotImplemented("*args is reserved for future arguments (e.g. suffix)")

def _validate_join_arg_on(on, sql_on = None):
def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None):
# handle sql on case
if sql_on is not None:
if on is not None:
Expand All @@ -991,12 +992,34 @@ def _validate_join_arg_on(on, sql_on = None):

# handle general cases
if on is None:
raise NotImplementedError("on arg currently cannot be None (default) for SQL")
# TODO: currently, we check for lhs and rhs tables to indicate whether
# a verb supports inferring columns. Otherwise, raise an error.
if lhs is not None and rhs is not None:
# TODO: consolidate with duplicate logic in pandas verb code
warnings.warn(
"No on column passed to join. "
"Inferring join columns instead using shared column names."
)

on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys())))

if not on_cols:
raise ValueError(
"No join column specified, or shared column names in join."
)

# trivial dict mapping shared names to themselves
warnings.warn("Detected shared columns: %s" % on_cols)
on = dict(zip(on_cols, on_cols))

else:
raise NotImplementedError("on arg currently cannot be None (default) for SQL")
elif isinstance(on, str):
on = {on: on}
elif isinstance(on, (list, tuple)):
on = dict(zip(on, on))


if not isinstance(on, Mapping):
raise TypeError("on must be a Mapping (e.g. dict)")

Expand Down
21 changes: 21 additions & 0 deletions siuba/tests/test_verb_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,27 @@ def test_semi_join_no_cross(backend, df1, df2):
DF1.iloc[:1,]
)

def test_semi_join_no_on_arg(backend, df1):
    """semi_join called without ``on`` warns and reports the shared column it detected."""
    df_ii = backend.load_df(data_frame(ii = [1,1]))

    with pytest.warns(UserWarning) as record:
        assert_equal_query(
            df1,
            semi_join(_, df_ii),
            DF1.iloc[:1,]
        )

    # pytest.warns records every warning raised inside the block, so match the
    # two expected messages by content instead of by position; an unrelated
    # warning emitted first would otherwise shift the indices.
    messages = [str(w.message) for w in record]
    assert any("No on column passed to join." in msg for msg in messages)
    assert any("['ii']" in msg for msg in messages)

def test_semi_join_no_on_arg_fail(backend, df1):
    """semi_join without ``on`` must fail when the right-hand table is loaded with only a ``ZZ`` column."""
    rhs_unshared = backend.load_df(data_frame(ZZ = [1,1]))

    # match= searches str(exception), the same text the message check targets
    with pytest.raises(Exception, match = "No join column specified"):
        collect(semi_join(df1, rhs_unshared))


def test_basic_anti_join_on_map(backend, df1, df2):
assert_frame_sort_equal(
Expand Down

0 comments on commit f5b36f8

Please sign in to comment.