Skip to content

Commit 18df89e

Browse files
authored
GH1169 Improve parameter types for DataFrame.pct_change and Series.pct_change (#1194)
* type pct_change kwargs according to shift params * remove duplicate params and add defaults * remove typeddict * address comments * add tests, fix axis argument type * fix shift/pct_change param types * fix return type of series pct_change, add test for series.pct_change
1 parent 75273f1 commit 18df89e

File tree

4 files changed

+45
-18
lines changed

4 files changed

+45
-18
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ from pandas._libs.lib import NoDefault
6767
from pandas._libs.missing import NAType
6868
from pandas._libs.tslibs import BaseOffset
6969
from pandas._libs.tslibs.nattype import NaTType
70+
from pandas._libs.tslibs.offsets import DateOffset
7071
from pandas._typing import (
7172
S1,
7273
AggFuncTypeBase,
@@ -87,7 +88,6 @@ from pandas._typing import (
8788
FilePath,
8889
FillnaOptions,
8990
FormattersType,
90-
Frequency,
9191
GroupByObjectNonScalar,
9292
HashableT,
9393
HashableT1,
@@ -830,10 +830,10 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
830830
) -> Self: ...
831831
def shift(
832832
self,
833-
periods: int = ...,
834-
freq: Frequency | dt.timedelta | None = ...,
833+
periods: int | Sequence[int] = ...,
834+
freq: DateOffset | dt.timedelta | _str | None = ...,
835835
axis: Axis = ...,
836-
fill_value: Hashable | None = ...,
836+
fill_value: Scalar | NAType | None = ...,
837837
) -> Self: ...
838838
@overload
839839
def set_index(
@@ -1989,9 +1989,10 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
19891989
self,
19901990
periods: int = ...,
19911991
fill_method: None = ...,
1992-
limit: int | None = ...,
1993-
freq=...,
1994-
**kwargs: Any, # TODO: make more precise https://github.com/pandas-dev/pandas-stubs/issues/1169
1992+
freq: DateOffset | dt.timedelta | _str | None = ...,
1993+
*,
1994+
axis: Axis = ...,
1995+
fill_value: Scalar | NAType | None = ...,
19951996
) -> Self: ...
19961997
def pop(self, item: _str) -> Series: ...
19971998
def pow(

pandas-stubs/core/series.pyi

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ from pandas._libs.lib import NoDefault
102102
from pandas._libs.missing import NAType
103103
from pandas._libs.tslibs import BaseOffset
104104
from pandas._libs.tslibs.nattype import NaTType
105+
from pandas._libs.tslibs.offsets import DateOffset
105106
from pandas._typing import (
106107
S1,
107108
S2,
@@ -125,7 +126,6 @@ from pandas._typing import (
125126
FilePath,
126127
FillnaOptions,
127128
FloatDtypeArg,
128-
Frequency,
129129
GroupByObjectNonScalar,
130130
HashableT1,
131131
IgnoreRaise,
@@ -1130,10 +1130,10 @@ class Series(IndexOpsMixin[S1], NDFrame):
11301130
) -> Series[S1]: ...
11311131
def shift(
11321132
self,
1133-
periods: int = ...,
1134-
freq: Frequency | timedelta | None = ...,
1135-
axis: AxisIndex = ...,
1136-
fill_value: object | None = ...,
1133+
periods: int | Sequence[int] = ...,
1134+
freq: DateOffset | timedelta | _str | None = ...,
1135+
axis: Axis = ...,
1136+
fill_value: Scalar | NAType | None = ...,
11371137
) -> UnknownSeries: ...
11381138
def info(
11391139
self,
@@ -1549,11 +1549,11 @@ class Series(IndexOpsMixin[S1], NDFrame):
15491549
def pct_change(
15501550
self,
15511551
periods: int = ...,
1552-
fill_method: _str = ...,
1553-
limit: int | None = ...,
1554-
freq=...,
1555-
**kwargs: Any, # TODO: make more precise https://github.com/pandas-dev/pandas-stubs/issues/1169
1556-
) -> Series[S1]: ...
1552+
fill_method: None = ...,
1553+
freq: DateOffset | timedelta | _str | None = ...,
1554+
*,
1555+
fill_value: Scalar | NAType | None = ...,
1556+
) -> Series[float]: ...
15571557
def first_valid_index(self) -> Scalar: ...
15581558
def last_valid_index(self) -> Scalar: ...
15591559
@overload

tests/test_frame.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2437,7 +2437,13 @@ def test_groupby_series_methods() -> None:
24372437

24382438
def test_dataframe_pct_change() -> None:
24392439
df = pd.DataFrame({"x": [1, 2, 2, 3, 3], "y": [10, 20, 30, 40, 50]})
2440-
df.pct_change(fill_method=None)
2440+
check(assert_type(df.pct_change(), pd.DataFrame), pd.DataFrame)
2441+
check(assert_type(df.pct_change(fill_method=None), pd.DataFrame), pd.DataFrame)
2442+
check(
2443+
assert_type(df.pct_change(axis="columns", periods=-1), pd.DataFrame),
2444+
pd.DataFrame,
2445+
)
2446+
check(assert_type(df.pct_change(fill_value=0), pd.DataFrame), pd.DataFrame)
24412447

24422448

24432449
def test_indexslice_setitem():

tests/test_series.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,26 @@ def test_types_shift() -> None:
441441
check(assert_type(s.shift(freq="1D"), pd.Series), pd.Series, np.integer)
442442

443443

444+
def test_series_pct_change() -> None:
445+
s = pd.Series([1, 2, 3], index=pd.date_range("2020", periods=3))
446+
check(assert_type(s.pct_change(), "pd.Series[float]"), pd.Series, np.floating)
447+
check(
448+
assert_type(s.pct_change(fill_method=None), "pd.Series[float]"),
449+
pd.Series,
450+
np.floating,
451+
)
452+
check(
453+
assert_type(s.pct_change(periods=-1), "pd.Series[float]"),
454+
pd.Series,
455+
np.floating,
456+
)
457+
check(
458+
assert_type(s.pct_change(fill_value=0), "pd.Series[float]"),
459+
pd.Series,
460+
np.floating,
461+
)
462+
463+
444464
def test_types_rank() -> None:
445465
s = pd.Series([1, 1, 2, 5, 6, np.nan])
446466
s.rank()

0 commit comments

Comments
 (0)