Skip to content

Commit

Permalink
chore(test): consolidate ifelse() and case() tests (#9560)
Browse files Browse the repository at this point in the history
  • Loading branch information
NickCrews authored Jul 15, 2024
1 parent 5b396e0 commit 9b797f6
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 165 deletions.
58 changes: 0 additions & 58 deletions ibis/backends/dask/tests/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,64 +773,6 @@ def q_fun(x, quantile):
tm.assert_series_equal(result, expected, check_index=False)


def test_searched_case_scalar(client):
    """Execute a searched CASE over boolean literals and expect ``1``."""
    builder = ibis.case().when(True, 1).when(False, 2)
    outcome = client.execute(builder.end())
    assert outcome == np.int8(1)


def test_searched_case_column(batting, batting_pandas_df):
    """Compare a searched CASE over columns with the ``np.select`` equivalent."""
    frame = batting_pandas_df

    builder = ibis.case()
    builder = builder.when(batting.RBI < 5, "really bad team")
    builder = builder.when(batting.teamID == "PH1", "ph1 team")
    builder = builder.else_(batting.teamID)
    result = builder.end().execute()

    conditions = [frame.RBI < 5, frame.teamID == "PH1"]
    labels = ["really bad team", "ph1 team"]
    expected = pd.Series(np.select(conditions, labels, frame.teamID))

    # Names are not compared: only the values of the two series must agree.
    tm.assert_series_equal(result, expected, check_names=False)


def test_simple_case_scalar(client):
    """Execute a simple CASE on the literal ``2`` and expect ``2 - 1``."""
    two = ibis.literal(2)
    builder = two.case()
    for candidate, outcome in [(2, two - 1), (3, two + 1), (4, two + 2)]:
        builder = builder.when(candidate, outcome)
    result = client.execute(builder.end())
    assert result == np.int8(1)


def test_simple_case_column(batting, batting_pandas_df):
    """Compare a simple CASE on ``RBI`` with the ``np.select`` equivalent."""
    frame = batting_pandas_df
    # Insertion order matters: branches are added as 5, 4, 3.
    labels = {5: "five", 4: "four", 3: "three"}

    builder = batting.RBI.case()
    for value, label in labels.items():
        builder = builder.when(value, label)
    result = builder.else_("could be good?").end().execute()

    expected = pd.Series(
        np.select(
            [frame.RBI == value for value in labels],
            list(labels.values()),
            "could be good?",
        )
    )
    tm.assert_series_equal(result, expected, check_names=False)


def test_table_distinct(t, df):
expr = t[["dup_strings"]].distinct()
result = expr.compile()
Expand Down
58 changes: 0 additions & 58 deletions ibis/backends/pandas/tests/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,64 +683,6 @@ def test_summary_non_numeric(batting, batting_df):
assert dict(result.iloc[0]) == expected


def test_searched_case_scalar(client):
    """Run a searched CASE whose first branch is literally true; expect ``1``."""
    case_expr = ibis.case().when(True, 1).when(False, 2).end()
    assert client.execute(case_expr) == np.int8(1)


def test_searched_case_column(batting, batting_df):
    """Check a column-wise searched CASE against ``np.select`` on the frame."""
    low_rbi = batting.RBI < 5
    is_ph1 = batting.teamID == "PH1"
    case_expr = (
        ibis.case()
        .when(low_rbi, "really bad team")
        .when(is_ph1, "ph1 team")
        .else_(batting.teamID)
        .end()
    )
    result = case_expr.execute()

    selected = np.select(
        [batting_df.RBI < 5, batting_df.teamID == "PH1"],
        ["really bad team", "ph1 team"],
        batting_df.teamID,
    )
    tm.assert_series_equal(result, pd.Series(selected))


def test_simple_case_scalar(client):
    """Run a simple CASE on literal ``2``; the matching branch is ``x - 1``."""
    lit = ibis.literal(2)
    builder = lit.case()
    builder = builder.when(2, lit - 1)
    builder = builder.when(3, lit + 1)
    builder = builder.when(4, lit + 2)
    assert client.execute(builder.end()) == np.int8(1)


def test_simple_case_column(batting, batting_df):
    """Check a simple CASE on ``RBI`` against ``np.select`` on the frame."""
    fallback = "could be good?"
    branches = [(5, "five"), (4, "four"), (3, "three")]

    builder = batting.RBI.case()
    for value, label in branches:
        builder = builder.when(value, label)
    result = builder.else_(fallback).end().execute()

    conditions = [batting_df.RBI == value for value, _ in branches]
    choices = [label for _, label in branches]
    expected = pd.Series(np.select(conditions, choices, fallback))
    tm.assert_series_equal(result, expected)


def test_non_range_index():
def do_replace(col):
return col.cases(
Expand Down
149 changes: 149 additions & 0 deletions ibis/backends/tests/test_conditionals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from __future__ import annotations

from collections import Counter

import numpy as np
import pandas as pd
import pytest

import ibis


def test_ifelse_select(backend, alltypes, df):
    """Select ``int_col`` plus an ``ifelse`` column and compare with pandas."""
    where_col = ibis.ifelse(alltypes["int_col"] == 0, 42, -1)
    where_col = where_col.cast("int64").name("where_col")
    result = alltypes.select(["int_col", where_col]).execute()

    # Build the expectation by filling -1 everywhere, then patching the zeros.
    expected = df.loc[:, ["int_col"]].copy()
    expected["where_col"] = -1
    zero_rows = expected["int_col"] == 0
    expected.loc[zero_rows, "where_col"] = 42

    backend.assert_frame_equal(result, expected)


def test_ifelse_column(backend, alltypes, df):
    """Execute a standalone ``ifelse`` column and compare with ``np.where``."""
    is_zero = alltypes["int_col"] == 0
    expr = ibis.ifelse(is_zero, 42, -1).cast("int64").name("where_col")
    result = expr.execute()

    values = np.where(df.int_col == 0, 42, -1)
    expected = pd.Series(values, name="where_col", dtype="int64")

    backend.assert_series_equal(result, expected)


def test_substitute(backend):
    """Substitute NULLs produced by ``nullif("1")`` and count the replacements.

    The count of rows carrying the replacement value is expected to equal the
    table's row count floor-divided by 10.
    """
    replacement = "400"
    alltypes = backend.functional_alltypes

    substituted = alltypes.string_col.nullif("1").substitute({None: replacement})
    counts = substituted.name("subs").value_counts()
    matching = counts.filter(lambda row: row.subs == replacement)

    observed = matching["subs_count"].execute()[0]
    assert observed == alltypes.count().execute() // 10


@pytest.mark.parametrize(
    "inp, exp",
    [
        pytest.param(
            lambda: (
                ibis.literal(1)
                .case()
                .when(1, "one")
                .when(2, "two")
                .else_("other")
                .end()
            ),
            "one",
            id="one_kwarg",
        ),
        pytest.param(
            lambda: ibis.literal(5).case().when(1, "one").when(2, "two").end(),
            None,
            id="fallthrough",
        ),
    ],
)
def test_value_cases_scalar(con, inp, exp):
    """Execute a scalar value-CASE built lazily by ``inp`` and check ``exp``.

    An ``exp`` of ``None`` means no branch matched and there was no ELSE, so
    the result only has to be NA-like.
    """
    result = con.execute(inp())
    if exp is not None:
        assert result == exp
        return
    assert pd.isna(result)


@pytest.mark.broken(
    "exasol",
    reason="the int64 RBI column is .to_pandas()ed to an object column, which is incomparable to ints",
    raises=AssertionError,
)
def test_value_cases_column(batting):
    """Compare a value-CASE on ``RBI`` with ``np.select``, ignoring row order."""
    frame = batting.to_pandas()
    fallback = "could be good?"
    branches = [(5, "five"), (4, "four"), (3, "three")]

    builder = batting.RBI.case()
    for value, label in branches:
        builder = builder.when(value, label)
    result = builder.else_(fallback).end().execute()

    expected = np.select(
        [frame.RBI == value for value, _ in branches],
        [label for _, label in branches],
        fallback,
    )

    # Compare as multisets so the check does not depend on row ordering.
    assert Counter(result) == Counter(expected)


def test_ibis_cases_scalar():
    """Execute a scalar value-CASE on literal ``5`` and expect ``"five"``.

    NOTE(review): this test requests no backend/connection fixture, so
    ``execute()`` is not tied to a fixture-provided backend here -- confirm
    that is intended.
    """
    builder = ibis.literal(5).case()
    builder = builder.when(5, "five").when(4, "four")
    assert builder.end().execute() == "five"


@pytest.mark.broken(
    ["sqlite", "exasol"],
    reason="the int64 RBI column is .to_pandas()ed to an object column, which is incomparable to 5",
    raises=TypeError,
)
def test_ibis_cases_column(batting):
    """Compare a searched CASE over columns with ``np.select``, ignoring order."""
    frame = batting.to_pandas()

    low_rbi = batting.RBI < 5
    is_ph1 = batting.teamID == "PH1"
    builder = ibis.case().when(low_rbi, "really bad team").when(is_ph1, "ph1 team")
    result = builder.else_(batting.teamID).end().execute()

    conditions = [frame.RBI < 5, frame.teamID == "PH1"]
    labels = ["really bad team", "ph1 team"]
    expected = np.select(conditions, labels, frame.teamID)

    # Compare as multisets so the check does not depend on row ordering.
    assert Counter(result) == Counter(expected)


@pytest.mark.broken("clickhouse", reason="special case this and returns 'oops'")
def test_value_cases_null(con):
    """A NULL subject must skip ``WHEN NULL`` and land on the ELSE branch."""
    null_subject = ibis.literal(5).nullif(5)
    expr = null_subject.case().when(None, "oops").else_("expected").end()
    assert con.execute(expr) == "expected"
49 changes: 0 additions & 49 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,42 +1088,6 @@ def test_zero_ifnull_column(backend, alltypes, df):
backend.assert_series_equal(result, expected)


def test_ifelse_select(backend, alltypes, df):
    """Project ``int_col`` with an ``ifelse`` column and compare with pandas."""
    flag = ibis.ifelse(alltypes["int_col"] == 0, 42, -1)
    projection = alltypes.select(
        ["int_col", flag.cast("int64").name("where_col")]
    )
    result = projection.execute()

    # Expected frame: default -1, overwritten with 42 where int_col is zero.
    expected = df.loc[:, ["int_col"]].copy()
    expected["where_col"] = -1
    expected.loc[expected["int_col"] == 0, "where_col"] = 42

    backend.assert_frame_equal(result, expected)


def test_ifelse_column(backend, alltypes, df):
    """Run ``ifelse`` as a single column and compare with ``np.where``."""
    condition = alltypes["int_col"] == 0
    branch = ibis.ifelse(condition, 42, -1)
    result = branch.cast("int64").name("where_col").execute()

    expected = pd.Series(
        np.where(df.int_col == 0, 42, -1), name="where_col", dtype="int64"
    )
    backend.assert_series_equal(result, expected)


def test_select_filter(backend, alltypes, df):
t = alltypes

Expand Down Expand Up @@ -2326,19 +2290,6 @@ def test_sample_with_seed(backend):
backend.assert_frame_equal(df1, df2)


def test_substitute(backend):
    """Replace NULLs from ``nullif("1")`` and verify how many rows were hit.

    The number of rows holding the replacement is expected to equal the total
    row count floor-divided by 10.
    """
    filler = "400"
    table = backend.functional_alltypes

    nulled = table.string_col.nullif("1")
    counted = nulled.substitute({None: filler}).name("subs").value_counts()
    only_filler = counted.filter(lambda rows: rows.subs == filler)

    hits = only_filler["subs_count"].execute()[0]
    assert hits == table.count().execute() // 10


@pytest.mark.notimpl(
["dask", "pandas", "polars"], raises=NotImplementedError, reason="not a SQL backend"
)
Expand Down

0 comments on commit 9b797f6

Please sign in to comment.