Skip to content

Commit

Permalink
fix: warning emits an error (apache#28524)
Browse files Browse the repository at this point in the history
  • Loading branch information
eschutho authored Nov 1, 2024
1 parent 3ec3f0a commit d466383
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 1 deletion.
7 changes: 6 additions & 1 deletion superset/utils/pandas_postprocessing/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,10 @@ def compare( # pylint: disable=too-many-arguments
df = pd.concat([df, diff_df], axis=1)

if drop_original_columns:
df = df.drop(source_columns + compare_columns, axis=1)
level = (
0
if isinstance(df.columns, pd.MultiIndex) and df.columns.nlevels > 1
else None
)
df = df.drop(source_columns + compare_columns, axis=1, level=level)
return df
67 changes: 67 additions & 0 deletions tests/unit_tests/pandas_postprocessing/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import io
import sys

import pandas as pd

from superset.constants import PandasPostprocessingCompare as PPC
Expand Down Expand Up @@ -179,6 +182,70 @@ def test_compare_multi_index_column():
)


def test_compare_multi_index_column_non_lex_sorted():
index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
index.name = "__timestamp"

iterables = [["m1", "m2"], ["a", "b"], ["x", "y"]]
columns = pd.MultiIndex.from_product(iterables, names=[None, "level1", "level2"])

df = pd.DataFrame(index=index, columns=columns, data=1)

# Define a non-lexicographical column order
# arrange them as m1, m2 instead of m2, m1
new_columns_order = [
("m1", "a", "x"),
("m1", "a", "y"),
("m1", "b", "x"),
("m1", "b", "y"),
("m2", "a", "x"),
("m2", "a", "y"),
("m2", "b", "x"),
("m2", "b", "y"),
]

df.columns = pd.MultiIndex.from_tuples(
new_columns_order, names=["level1", "level2", None]
)

# to capture stderr
stderr = sys.stderr
sys.stderr = io.StringIO()

try:
post_df = pp.compare(
df,
source_columns=["m1"],
compare_columns=["m2"],
compare_type=PPC.DIFF,
drop_original_columns=True,
)
assert sys.stderr.getvalue() == ""
finally:
sys.stderr = stderr

flat_df = pp.flatten(post_df)
"""
__timestamp difference__m1__m2, a, x difference__m1__m2, a, y difference__m1__m2, b, x difference__m1__m2, b, y
0 2021-01-01 0 0 0 0
1 2021-01-02 0 0 0 0
2 2021-01-03 0 0 0 0
"""
assert flat_df.equals(
pd.DataFrame(
data={
"__timestamp": pd.to_datetime(
["2021-01-01", "2021-01-02", "2021-01-03"]
),
"difference__m1__m2, a, x": [0, 0, 0],
"difference__m1__m2, a, y": [0, 0, 0],
"difference__m1__m2, b, x": [0, 0, 0],
"difference__m1__m2, b, y": [0, 0, 0],
}
)
)


def test_compare_after_pivot():
pivot_df = pp.pivot(
df=multiple_metrics_df,
Expand Down

0 comments on commit d466383

Please sign in to comment.