diff --git a/docs/src/python/user-guide/expressions/aggregation.py b/docs/src/python/user-guide/expressions/aggregation.py
index e25917b2de38..f2d75cbd3726 100644
--- a/docs/src/python/user-guide/expressions/aggregation.py
+++ b/docs/src/python/user-guide/expressions/aggregation.py
@@ -6,7 +6,7 @@
 # --8<-- [start:dataframe]
 url = "https://theunitedstates.io/congress-legislators/legislators-historical.csv"
 
-dtypes = {
+schema_overrides = {
     "first_name": pl.Categorical,
     "gender": pl.Categorical,
     "type": pl.Categorical,
@@ -14,7 +14,7 @@
     "party": pl.Categorical,
 }
 
-dataset = pl.read_csv(url, dtypes=dtypes).with_columns(
+dataset = pl.read_csv(url, schema_overrides=schema_overrides).with_columns(
     pl.col("birthday").str.to_date(strict=False)
 )
 # --8<-- [end:dataframe]
diff --git a/py-polars/polars/_utils/deprecation.py b/py-polars/polars/_utils/deprecation.py
index 9c3382f4982d..6f53205a4f50 100644
--- a/py-polars/polars/_utils/deprecation.py
+++ b/py-polars/polars/_utils/deprecation.py
@@ -156,7 +156,7 @@ def _rename_keyword_argument(
         )
         raise TypeError(msg)
     issue_deprecation_warning(
-        f"`the argument {old_name}` for `{func_name}` is deprecated."
+        f"The argument `{old_name}` for `{func_name}` is deprecated."
         f" It has been renamed to `{new_name}`.",
         version=version,
     )
diff --git a/py-polars/polars/io/csv/batched_reader.py b/py-polars/polars/io/csv/batched_reader.py
index 5a007239cb86..b622d4fcc739 100644
--- a/py-polars/polars/io/csv/batched_reader.py
+++ b/py-polars/polars/io/csv/batched_reader.py
@@ -35,7 +35,7 @@ def __init__(
         comment_prefix: str | None = None,
         quote_char: str | None = '"',
         skip_rows: int = 0,
-        dtypes: None | (SchemaDict | Sequence[PolarsDataType]) = None,
+        schema_overrides: SchemaDict | Sequence[PolarsDataType] | None = None,
         null_values: str | Sequence[str] | dict[str, str] | None = None,
         missing_utf8_is_empty_string: bool = False,
         ignore_errors: bool = False,
@@ -61,15 +61,15 @@ def __init__(
         dtype_list: Sequence[tuple[str, PolarsDataType]] | None = None
         dtype_slice: Sequence[PolarsDataType] | None = None
 
-        if dtypes is not None:
-            if isinstance(dtypes, dict):
+        if schema_overrides is not None:
+            if isinstance(schema_overrides, dict):
                 dtype_list = []
-                for k, v in dtypes.items():
+                for k, v in schema_overrides.items():
                     dtype_list.append((k, py_type_to_dtype(v)))
-            elif isinstance(dtypes, Sequence):
-                dtype_slice = dtypes
+            elif isinstance(schema_overrides, Sequence):
+                dtype_slice = schema_overrides
             else:
-                msg = "`dtypes` arg should be list or dict"
+                msg = "`schema_overrides` arg should be list or dict"
                 raise TypeError(msg)
 
         processed_null_values = _process_null_values(null_values)
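Context for the hunks above: `schema_overrides` accepts the same two shapes the old `dtypes` parameter did, either a `{column_name: dtype}` mapping or a bare sequence applied positionally. A minimal usage sketch of both forms (illustration only, not part of the patch; assumes polars >= 0.20.31, where `dtypes=` still works but emits a DeprecationWarning):

import io

import polars as pl

buf = io.StringIO("a,b\n1,x\n2,y\n")

# dict form: override inference only for the named columns
df = pl.read_csv(buf, schema_overrides={"a": pl.Int8})
assert df.schema["a"] == pl.Int8

# sequence form: overrides are matched to columns by position
buf.seek(0)
df = pl.read_csv(buf, schema_overrides=[pl.Int8, pl.String])
assert df.dtypes == [pl.Int8, pl.String]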
diff --git a/py-polars/polars/io/csv/functions.py b/py-polars/polars/io/csv/functions.py
index dc71e264ba6a..5693e089999b 100644
--- a/py-polars/polars/io/csv/functions.py
+++ b/py-polars/polars/io/csv/functions.py
@@ -32,6 +32,7 @@
     from polars.type_aliases import CsvEncoding, PolarsDataType, SchemaDict
 
 
+@deprecate_renamed_parameter("dtypes", "schema_overrides", version="0.20.31")
 @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4")
 @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4")
 @deprecate_renamed_parameter(
@@ -47,8 +48,10 @@ def read_csv(
     comment_prefix: str | None = None,
     quote_char: str | None = '"',
     skip_rows: int = 0,
-    dtypes: Mapping[str, PolarsDataType] | Sequence[PolarsDataType] | None = None,
     schema: SchemaDict | None = None,
+    schema_overrides: (
+        Mapping[str, PolarsDataType] | Sequence[PolarsDataType] | None
+    ) = None,
     null_values: str | Sequence[str] | dict[str, str] | None = None,
     missing_utf8_is_empty_string: bool = False,
     ignore_errors: bool = False,
@@ -103,12 +106,12 @@ def read_csv(
         Set to None to turn off special handling and escaping of quotes.
     skip_rows
         Start reading after `skip_rows` lines.
-    dtypes
-        Overwrite dtypes for specific or all columns during schema inference.
     schema
         Provide the schema. This means that polars doesn't do schema inference.
-        This argument expects the complete schema, whereas `dtypes` can be used
-        to partially overwrite a schema.
+        This argument expects the complete schema, whereas `schema_overrides` can be
+        used to partially overwrite a schema.
+    schema_overrides
+        Overwrite dtypes for specific or all columns during schema inference.
     null_values
         Values to interpret as null values. You can provide a:
@@ -124,7 +127,7 @@ def read_csv(
         Try to keep reading lines if some lines yield errors.
         Before using this option, try to increase the number of lines used for schema
         inference with e.g `infer_schema_length=10000` or override automatic dtype
-        inference for specific columns with the `dtypes` option or use
+        inference for specific columns with the `schema_overrides` option or use
         `infer_schema_length=0` to read all columns as `pl.String` to check which
         values might cause an issue.
     try_parse_dates
@@ -205,7 +208,7 @@ def read_csv(
     If the schema is inferred incorrectly (e.g. as `pl.Int64` instead of `pl.Float64`),
     try to increase the number of lines used to infer the schema with
     `infer_schema_length` or override the inferred dtype for those columns with
-    `dtypes`.
+    `schema_overrides`.
 
     This operation defaults to a `rechunk` operation at the end, meaning that all
     data will be stored continuously in memory. Set `rechunk=False` if you are benchmarking
@@ -254,7 +257,7 @@ def read_csv(
     if (
         use_pyarrow
-        and dtypes is None
+        and schema_overrides is None
         and n_rows is None
         and n_threads is None
         and not low_memory
@@ -321,9 +324,9 @@ def read_csv(
             return _update_columns(df, new_columns)
         return df
 
-    if projection and dtypes and isinstance(dtypes, list):
-        if len(projection) < len(dtypes):
-            msg = "more dtypes overrides are specified than there are selected columns"
+    if projection and schema_overrides and isinstance(schema_overrides, list):
+        if len(projection) < len(schema_overrides):
+            msg = "more schema overrides are specified than there are selected columns"
             raise ValueError(msg)
 
         # Fix list of dtypes when used together with projection as polars CSV reader
@@ -331,22 +334,22 @@ def read_csv(
         dtypes_list: list[PolarsDataType] = [String] * (max(projection) + 1)
 
         for idx, column_idx in enumerate(projection):
-            if idx < len(dtypes):
-                dtypes_list[column_idx] = dtypes[idx]
+            if idx < len(schema_overrides):
+                dtypes_list[column_idx] = schema_overrides[idx]
 
-        dtypes = dtypes_list
+        schema_overrides = dtypes_list
 
-    if columns and dtypes and isinstance(dtypes, list):
-        if len(columns) < len(dtypes):
+    if columns and schema_overrides and isinstance(schema_overrides, list):
+        if len(columns) < len(schema_overrides):
             msg = "more dtypes overrides are specified than there are selected columns"
             raise ValueError(msg)
 
         # Map list of dtypes when used together with selected columns as a dtypes dict
         # so the dtypes are applied to the correct column instead of the first x
         # columns.
-        dtypes = dict(zip(columns, dtypes))
+        schema_overrides = dict(zip(columns, schema_overrides))
 
-    if new_columns and dtypes and isinstance(dtypes, dict):
+    if new_columns and schema_overrides and isinstance(schema_overrides, dict):
         current_columns = None
 
         # As new column names are not available yet while parsing the CSV file, rename
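This remapping is why list-form overrides keep working with a `columns=` selection after the rename: the list is zipped with the selected names into a dict, so each dtype lands on the column the caller picked rather than on the first N columns of the file. Illustration (not part of the patch):

import io

import polars as pl

csv = "a,b,c\n1,2,3\n"
# [pl.Int32, pl.String] zips with ["c", "b"] -> {"c": pl.Int32, "b": pl.String}
df = pl.read_csv(
    io.StringIO(csv), columns=["c", "b"], schema_overrides=[pl.Int32, pl.String]
)
assert df.dtypes == [pl.String, pl.Int32]  # columns come back in file order: b, c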
@@ -387,26 +390,26 @@ def read_csv(
 
         else:
             # When a header is present, column names are not known yet.
-            if len(dtypes) <= len(new_columns):
+            if len(schema_overrides) <= len(new_columns):
                 # If dtypes dictionary contains less or same amount of values than new
                 # column names a list of dtypes can be created if all listed column
                 # names in dtypes dictionary appear in the first consecutive new column
                 # names.
                 dtype_list = [
-                    dtypes[new_column_name]
-                    for new_column_name in new_columns[0 : len(dtypes)]
-                    if new_column_name in dtypes
+                    schema_overrides[new_column_name]
+                    for new_column_name in new_columns[0 : len(schema_overrides)]
+                    if new_column_name in schema_overrides
                 ]
-                if len(dtype_list) == len(dtypes):
-                    dtypes = dtype_list
+                if len(dtype_list) == len(schema_overrides):
+                    schema_overrides = dtype_list
 
-        if current_columns and isinstance(dtypes, dict):
+        if current_columns and isinstance(schema_overrides, dict):
             new_to_current = dict(zip(new_columns, current_columns))
             # Change new column names to current column names in dtype.
-            dtypes = {
+            schema_overrides = {
                 new_to_current.get(column_name, column_name): column_dtype
-                for column_name, column_dtype in dtypes.items()
+                for column_name, column_dtype in schema_overrides.items()
             }
 
     with prepare_file_arg(
@@ -424,7 +427,7 @@ def read_csv(
             comment_prefix=comment_prefix,
             quote_char=quote_char,
             skip_rows=skip_rows,
-            dtypes=dtypes,
+            schema_overrides=schema_overrides,
             schema=schema,
             null_values=null_values,
             missing_utf8_is_empty_string=missing_utf8_is_empty_string,
@@ -462,8 +465,8 @@ def _read_csv_impl(
     comment_prefix: str | None = None,
     quote_char: str | None = '"',
     skip_rows: int = 0,
-    dtypes: None | (SchemaDict | Sequence[PolarsDataType]) = None,
     schema: None | SchemaDict = None,
+    schema_overrides: None | (SchemaDict | Sequence[PolarsDataType]) = None,
     null_values: str | Sequence[str] | dict[str, str] | None = None,
     missing_utf8_is_empty_string: bool = False,
     ignore_errors: bool = False,
@@ -497,15 +500,15 @@ def _read_csv_impl(
     dtype_list: Sequence[tuple[str, PolarsDataType]] | None = None
     dtype_slice: Sequence[PolarsDataType] | None = None
 
-    if dtypes is not None:
-        if isinstance(dtypes, dict):
+    if schema_overrides is not None:
+        if isinstance(schema_overrides, dict):
             dtype_list = []
-            for k, v in dtypes.items():
+            for k, v in schema_overrides.items():
                 dtype_list.append((k, py_type_to_dtype(v)))
-        elif isinstance(dtypes, Sequence):
-            dtype_slice = dtypes
+        elif isinstance(schema_overrides, Sequence):
+            dtype_slice = schema_overrides
         else:
-            msg = f"`dtypes` should be of type list or dict, got {type(dtypes).__name__!r}"
+            msg = f"`schema_overrides` should be of type list or dict, got {type(schema_overrides).__name__!r}"
             raise TypeError(msg)
 
     processed_null_values = _process_null_values(null_values)
@@ -518,8 +521,8 @@ def _read_csv_impl(
             dtypes_dict = dict(dtype_list)
         if dtype_slice is not None:
             msg = (
-                "cannot use glob patterns and unnamed dtypes as `dtypes` argument"
-                "\n\nUse `dtypes`: Mapping[str, Type[DataType]]"
+                "cannot use glob patterns and unnamed dtypes as `schema_overrides` argument"
+                "\n\nUse `schema_overrides`: Mapping[str, Type[DataType]]"
             )
             raise ValueError(msg)
         from polars import scan_csv
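One edge kept intact by the rename: when the source is a glob pattern, `read_csv` delegates to `scan_csv`, and that path accepts only overrides keyed by column name, hence the error above for unnamed sequences. Sketch (the file pattern is a hypothetical placeholder):

import polars as pl

# named overrides work with glob patterns (delegated to scan_csv)
df = pl.read_csv("data_*.csv", schema_overrides={"id": pl.UInt32})

# an unnamed list would raise:
#   ValueError: cannot use glob patterns and unnamed dtypes as `schema_overrides` argument
# pl.read_csv("data_*.csv", schema_overrides=[pl.UInt32])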
@@ -531,8 +534,8 @@ def _read_csv_impl(
             comment_prefix=comment_prefix,
             quote_char=quote_char,
             skip_rows=skip_rows,
-            dtypes=dtypes_dict,
             schema=schema,
+            schema_overrides=dtypes_dict,
             null_values=null_values,
             missing_utf8_is_empty_string=missing_utf8_is_empty_string,
             ignore_errors=ignore_errors,
@@ -597,6 +600,7 @@ def _read_csv_impl(
     return wrap_df(pydf)
 
 
+@deprecate_renamed_parameter("dtypes", "schema_overrides", version="0.20.31")
 @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4")
 @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4")
 @deprecate_renamed_parameter(
@@ -612,7 +616,9 @@ def read_csv_batched(
     comment_prefix: str | None = None,
     quote_char: str | None = '"',
     skip_rows: int = 0,
-    dtypes: Mapping[str, PolarsDataType] | Sequence[PolarsDataType] | None = None,
+    schema_overrides: (
+        Mapping[str, PolarsDataType] | Sequence[PolarsDataType] | None
+    ) = None,
     null_values: str | Sequence[str] | dict[str, str] | None = None,
     missing_utf8_is_empty_string: bool = False,
     ignore_errors: bool = False,
@@ -668,7 +674,7 @@ def read_csv_batched(
         Set to None to turn off special handling and escaping of quotes.
     skip_rows
         Start reading after `skip_rows` lines.
-    dtypes
+    schema_overrides
         Overwrite dtypes during inference.
     null_values
         Values to interpret as null values. You can provide a:
@@ -787,9 +793,9 @@ def read_csv_batched(
         )
         raise ValueError(msg)
 
-    if projection and dtypes and isinstance(dtypes, list):
-        if len(projection) < len(dtypes):
-            msg = "more dtypes overrides are specified than there are selected columns"
+    if projection and schema_overrides and isinstance(schema_overrides, list):
+        if len(projection) < len(schema_overrides):
+            msg = "more schema overrides are specified than there are selected columns"
             raise ValueError(msg)
 
         # Fix list of dtypes when used together with projection as polars CSV reader
@@ -797,22 +803,22 @@ def read_csv_batched(
         dtypes_list: list[PolarsDataType] = [String] * (max(projection) + 1)
 
         for idx, column_idx in enumerate(projection):
-            if idx < len(dtypes):
-                dtypes_list[column_idx] = dtypes[idx]
+            if idx < len(schema_overrides):
+                dtypes_list[column_idx] = schema_overrides[idx]
 
-        dtypes = dtypes_list
+        schema_overrides = dtypes_list
 
-    if columns and dtypes and isinstance(dtypes, list):
-        if len(columns) < len(dtypes):
-            msg = "more dtypes overrides are specified than there are selected columns"
+    if columns and schema_overrides and isinstance(schema_overrides, list):
+        if len(columns) < len(schema_overrides):
+            msg = "more schema overrides are specified than there are selected columns"
             raise ValueError(msg)
 
         # Map list of dtypes when used together with selected columns as a dtypes dict
         # so the dtypes are applied to the correct column instead of the first x
         # columns.
-        dtypes = dict(zip(columns, dtypes))
+        schema_overrides = dict(zip(columns, schema_overrides))
 
-    if new_columns and dtypes and isinstance(dtypes, dict):
+    if new_columns and schema_overrides and isinstance(schema_overrides, dict):
         current_columns = None
 
         # As new column names are not available yet while parsing the CSV file, rename
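For reference, `read_csv_batched` forwards the renamed argument to `BatchedCsvReader` unchanged, so batched reads differ from `read_csv` only in how batches are pulled. Sketch (hypothetical file name):

import polars as pl

reader = pl.read_csv_batched("big.csv", schema_overrides={"id": pl.UInt64})
while (batches := reader.next_batches(5)) is not None:
    for df in batches:
        ...  # each batch is a regular DataFrame with the overridden schema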
@@ -847,26 +853,26 @@ def read_csv_batched(
 
         else:
             # When a header is present, column names are not known yet.
-            if len(dtypes) <= len(new_columns):
+            if len(schema_overrides) <= len(new_columns):
                 # If dtypes dictionary contains less or same amount of values than new
                 # column names a list of dtypes can be created if all listed column
                 # names in dtypes dictionary appear in the first consecutive new column
                 # names.
                 dtype_list = [
-                    dtypes[new_column_name]
-                    for new_column_name in new_columns[0 : len(dtypes)]
-                    if new_column_name in dtypes
+                    schema_overrides[new_column_name]
+                    for new_column_name in new_columns[0 : len(schema_overrides)]
+                    if new_column_name in schema_overrides
                 ]
-                if len(dtype_list) == len(dtypes):
-                    dtypes = dtype_list
+                if len(dtype_list) == len(schema_overrides):
+                    schema_overrides = dtype_list
 
-        if current_columns and isinstance(dtypes, dict):
+        if current_columns and isinstance(schema_overrides, dict):
             new_to_current = dict(zip(new_columns, current_columns))
             # Change new column names to current column names in dtype.
-            dtypes = {
+            schema_overrides = {
                 new_to_current.get(column_name, column_name): column_dtype
-                for column_name, column_dtype in dtypes.items()
+                for column_name, column_dtype in schema_overrides.items()
             }
 
     return BatchedCsvReader(
@@ -877,7 +883,7 @@ def read_csv_batched(
         comment_prefix=comment_prefix,
         quote_char=quote_char,
         skip_rows=skip_rows,
-        dtypes=dtypes,
+        schema_overrides=schema_overrides,
         null_values=null_values,
         missing_utf8_is_empty_string=missing_utf8_is_empty_string,
         ignore_errors=ignore_errors,
@@ -901,6 +907,7 @@ def read_csv_batched(
     )
 
 
+@deprecate_renamed_parameter("dtypes", "schema_overrides", version="0.20.31")
 @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4")
 @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4")
 @deprecate_renamed_parameter(
@@ -914,8 +921,8 @@ def scan_csv(
     comment_prefix: str | None = None,
     quote_char: str | None = '"',
     skip_rows: int = 0,
-    dtypes: SchemaDict | Sequence[PolarsDataType] | None = None,
     schema: SchemaDict | None = None,
+    schema_overrides: SchemaDict | Sequence[PolarsDataType] | None = None,
     null_values: str | Sequence[str] | dict[str, str] | None = None,
     missing_utf8_is_empty_string: bool = False,
     ignore_errors: bool = False,
@@ -963,14 +970,14 @@ def scan_csv(
     skip_rows
         Start reading after `skip_rows` lines. The header will be parsed at this
         offset.
-    dtypes
+    schema
+        Provide the schema. This means that polars doesn't do schema inference.
+        This argument expects the complete schema, whereas `schema_overrides` can be
+        used to partially overwrite a schema.
+    schema_overrides
         Overwrite dtypes during inference; should be a {colname:dtype,} dict or,
         if providing a list of strings to `new_columns`, a list of dtypes of
         the same length.
-    schema
-        Provide the schema. This means that polars doesn't do schema inference.
-        This argument expects the complete schema, whereas `dtypes` can be used
-        to partially overwrite a schema.
     null_values
         Values to interpret as null values. You can provide a:
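The reordered docstring entries make the division of labor explicit: `schema` supplies the complete schema and turns inference off, while `schema_overrides` patches individual columns on top of inference. Side-by-side sketch (hypothetical file name):

import polars as pl

# complete schema: no inference happens at all
lf_full = pl.scan_csv("data.csv", schema={"a": pl.Int16, "b": pl.String})

# partial override: "a" is forced, every other column is still inferred
lf_partial = pl.scan_csv("data.csv", schema_overrides={"a": pl.Int16})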
@@ -1085,7 +1092,7 @@ def scan_csv(
     >>> pl.scan_csv(
     ...     path,
     ...     new_columns=["idx", "txt"],
-    ...     dtypes=[pl.UInt16, pl.String],
+    ...     schema_overrides=[pl.UInt16, pl.String],
     ... ).collect()
     shape: (4, 2)
     ┌─────┬──────┐
@@ -1099,15 +1106,15 @@ def scan_csv(
     │ 4   ┆ read │
     └─────┴──────┘
     """
-    if not new_columns and isinstance(dtypes, Sequence):
-        msg = f"expected 'dtypes' dict, found {type(dtypes).__name__!r}"
+    if not new_columns and isinstance(schema_overrides, Sequence):
+        msg = f"expected 'schema_overrides' dict, found {type(schema_overrides).__name__!r}"
         raise TypeError(msg)
     elif new_columns:
         if with_column_names:
             msg = "cannot set both `with_column_names` and `new_columns`; mutually exclusive"
             raise ValueError(msg)
-        if dtypes and isinstance(dtypes, Sequence):
-            dtypes = dict(zip(new_columns, dtypes))
+        if schema_overrides and isinstance(schema_overrides, Sequence):
+            schema_overrides = dict(zip(new_columns, schema_overrides))
 
         # wrap new column names as a callable
         def with_column_names(cols: list[str]) -> list[str]:
@@ -1131,7 +1138,7 @@ def with_column_names(cols: list[str]) -> list[str]:
         comment_prefix=comment_prefix,
         quote_char=quote_char,
         skip_rows=skip_rows,
-        dtypes=dtypes,  # type: ignore[arg-type]
+        schema_overrides=schema_overrides,  # type: ignore[arg-type]
         schema=schema,
         null_values=null_values,
         missing_utf8_is_empty_string=missing_utf8_is_empty_string,
@@ -1163,8 +1170,8 @@ def _scan_csv_impl(
     comment_prefix: str | None = None,
     quote_char: str | None = '"',
     skip_rows: int = 0,
-    dtypes: SchemaDict | None = None,
     schema: SchemaDict | None = None,
+    schema_overrides: SchemaDict | None = None,
     null_values: str | Sequence[str] | dict[str, str] | None = None,
     missing_utf8_is_empty_string: bool = False,
     ignore_errors: bool = False,
@@ -1186,9 +1193,9 @@ def _scan_csv_impl(
     glob: bool = True,
 ) -> LazyFrame:
     dtype_list: list[tuple[str, PolarsDataType]] | None = None
-    if dtypes is not None:
+    if schema_overrides is not None:
         dtype_list = []
-        for k, v in dtypes.items():
+        for k, v in schema_overrides.items():
             dtype_list.append((k, py_type_to_dtype(v)))
     processed_null_values = _process_null_values(null_values)
diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py
index 4b3a143f865b..2499f45afd2a 100644
--- a/py-polars/polars/io/spreadsheet/functions.py
+++ b/py-polars/polars/io/spreadsheet/functions.py
@@ -702,13 +702,22 @@ def _csv_buffer_to_frame(
     if read_options is None:
         read_options = {}
     if schema_overrides:
-        if (csv_dtypes := read_options.get("dtypes", {})) and set(
-            csv_dtypes
-        ).intersection(schema_overrides):
+        csv_dtypes = read_options.get("dtypes", {})
+        if csv_dtypes:
+            issue_deprecation_warning(
+                "The `dtypes` parameter for `read_csv` is deprecated. It has been renamed to `schema_overrides`.",
+                version="0.20.31",
+            )
+        csv_schema_overrides = read_options.get("schema_overrides", csv_dtypes)
+
+        if csv_schema_overrides and set(csv_schema_overrides).intersection(
+            schema_overrides
+        ):
             msg = "cannot specify columns in both `schema_overrides` and `read_options['dtypes']`"
             raise ParameterCollisionError(msg)
+
         read_options = read_options.copy()
-        read_options["dtypes"] = {**csv_dtypes, **schema_overrides}
+        read_options["schema_overrides"] = {**csv_schema_overrides, **schema_overrides}
 
     # otherwise rewind the buffer and parse as csv
     csv.seek(0)
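The shim above lets spreadsheet callers migrate gradually: a nested `read_options={"dtypes": ...}` still parses, now with a deprecation warning, and is merged into the new key before being handed to the CSV reader. The new spelling for comparison (hypothetical workbook, sheet, and column names):

import polars as pl

df = pl.read_excel(
    "report.xlsx",
    sheet_name="data",
    read_options={"schema_overrides": {"qty": pl.UInt16}},
)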
diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py
index c310cab18fe1..a3f97877a1ea 100644
--- a/py-polars/tests/unit/datatypes/test_categorical.py
+++ b/py-polars/tests/unit/datatypes/test_categorical.py
@@ -75,7 +75,7 @@ def test_read_csv_categorical() -> None:
     f = io.BytesIO()
     f.write(b"col1,col2,col3,col4,col5,col6\n'foo',2,3,4,5,6\n'bar',8,9,10,11,12")
     f.seek(0)
-    df = pl.read_csv(f, has_header=True, dtypes={"col1": pl.Categorical})
+    df = pl.read_csv(f, has_header=True, schema_overrides={"col1": pl.Categorical})
     assert df["col1"].dtype == pl.Categorical
diff --git a/py-polars/tests/unit/datatypes/test_decimal.py b/py-polars/tests/unit/datatypes/test_decimal.py
index 9d14e0b8c549..445fbd35fa80 100644
--- a/py-polars/tests/unit/datatypes/test_decimal.py
+++ b/py-polars/tests/unit/datatypes/test_decimal.py
@@ -190,7 +190,7 @@ def test_read_csv_decimal(monkeypatch: Any) -> None:
     1.1,a
     0.01,a"""
 
-    df = pl.read_csv(csv.encode(), dtypes={"a": pl.Decimal(scale=2)})
+    df = pl.read_csv(csv.encode(), schema_overrides={"a": pl.Decimal(scale=2)})
     assert df.dtypes == [pl.Decimal(precision=None, scale=2), pl.String]
     assert df["a"].to_list() == [
         D("123.12"),
diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py
index d92deefa4ce2..639f2a89148e 100644
--- a/py-polars/tests/unit/io/test_csv.py
+++ b/py-polars/tests/unit/io/test_csv.py
@@ -50,7 +50,7 @@ def test_quoted_date() -> None:
 
 def test_date_pattern_with_datetime_override_10826() -> None:
     result = pl.read_csv(
         source=io.StringIO("col\n2023-01-01\n2023-02-01\n2023-03-01"),
-        dtypes={"col": pl.Datetime},
+        schema_overrides={"col": pl.Datetime},
     )
     expected = pl.Series(
         "col", [datetime(2023, 1, 1), datetime(2023, 2, 1), datetime(2023, 3, 1)]
@@ -59,7 +59,7 @@ def test_date_pattern_with_datetime_override_10826() -> None:
 
     result = pl.read_csv(
         source=io.StringIO("col\n2023-01-01T01:02:03\n2023-02-01\n2023-03-01"),
-        dtypes={"col": pl.Datetime},
+        schema_overrides={"col": pl.Datetime},
     )
     expected = pl.Series(
         "col",
@@ -362,7 +362,7 @@ def test_partial_dtype_overwrite() -> None:
         """
     )
     f = io.StringIO(csv)
-    df = pl.read_csv(f, dtypes=[pl.String])
+    df = pl.read_csv(f, schema_overrides=[pl.String])
     assert df.dtypes == [pl.String, pl.Int64, pl.Int64]
@@ -375,7 +375,7 @@ def test_dtype_overwrite_with_column_name_selection() -> None:
         """
     )
     f = io.StringIO(csv)
-    df = pl.read_csv(f, columns=["c", "b", "d"], dtypes=[pl.Int32, pl.String])
+    df = pl.read_csv(f, columns=["c", "b", "d"], schema_overrides=[pl.Int32, pl.String])
     assert df.dtypes == [pl.String, pl.Int32, pl.Int64]
@@ -388,7 +388,7 @@ def test_dtype_overwrite_with_column_idx_selection() -> None:
         """
     )
     f = io.StringIO(csv)
-    df = pl.read_csv(f, columns=[2, 1, 3], dtypes=[pl.Int32, pl.String])
+    df = pl.read_csv(f, columns=[2, 1, 3], schema_overrides=[pl.Int32, pl.String])
     # Columns without an explicit dtype set will get pl.String if dtypes is a list
     # if the column selection is done with column indices instead of column names.
     assert df.dtypes == [pl.String, pl.Int32, pl.String]
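The three tests above pin down the positional semantics that carry over to `schema_overrides`: a list shorter than the file only overrides the leading columns, and with index-based selection the unspecified selected columns fall back to `pl.String`. Condensed illustration (not part of the patch):

import io

import polars as pl

csv = "a,b,c\n1,2,3\n"
df = pl.read_csv(io.StringIO(csv), schema_overrides=[pl.String])
# only the first column is overridden; the rest are inferred as before
assert df.dtypes == [pl.String, pl.Int64, pl.Int64]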
@@ -488,7 +488,7 @@ def test_column_rename_and_dtype_overwrite() -> None:
     df = pl.read_csv(
         f,
         new_columns=["A", "B", "C"],
-        dtypes={"A": pl.String, "B": pl.Int64, "C": pl.Float32},
+        schema_overrides={"A": pl.String, "B": pl.Int64, "C": pl.Float32},
     )
     assert df.dtypes == [pl.String, pl.Int64, pl.Float32]
@@ -497,7 +497,7 @@ def test_column_rename_and_dtype_overwrite() -> None:
         f,
         columns=["a", "c"],
         new_columns=["A", "C"],
-        dtypes={"A": pl.String, "C": pl.Float32},
+        schema_overrides={"A": pl.String, "C": pl.Float32},
     )
     assert df.dtypes == [pl.String, pl.Float32]
@@ -511,7 +511,7 @@ def test_column_rename_and_dtype_overwrite() -> None:
     df = pl.read_csv(
         f,
         new_columns=["A", "B", "C"],
-        dtypes={"A": pl.String, "C": pl.Float32},
+        schema_overrides={"A": pl.String, "C": pl.Float32},
         has_header=False,
     )
     assert df.dtypes == [pl.String, pl.Int64, pl.Float32]
@@ -755,7 +755,7 @@ def test_ignore_try_parse_dates() -> None:
     dtypes: dict[str, type[pl.DataType]] = {
         k: pl.String for k in headers
     }  # Forces String type for every column
-    df = pl.read_csv(csv, columns=headers, dtypes=dtypes)
+    df = pl.read_csv(csv, columns=headers, schema_overrides=dtypes)
     assert df.dtypes == [pl.String, pl.String, pl.String]
@@ -786,7 +786,7 @@ def test_csv_date_handling() -> None:
     out = pl.read_csv(csv.encode(), try_parse_dates=True)
     assert_frame_equal(out, expected)
     dtypes = {"date": pl.Date}
-    out = pl.read_csv(csv.encode(), dtypes=dtypes)
+    out = pl.read_csv(csv.encode(), schema_overrides=dtypes)
     assert_frame_equal(out, expected)
@@ -837,7 +837,9 @@ def test_csv_date_dtype_ignore_errors() -> None:
     !!
     """
     )
-    out = pl.read_csv(csv.encode(), ignore_errors=True, dtypes={"date": pl.Date})
+    out = pl.read_csv(
+        csv.encode(), ignore_errors=True, schema_overrides={"date": pl.Date}
+    )
     expected = pl.DataFrame(
         {
             "date": [
@@ -865,7 +867,9 @@ def test_csv_globbing(io_files_path: Path) -> None:
     assert df.row(0) == ("vegetables", 2)
 
     with pytest.raises(ValueError):
-        _ = pl.read_csv(path, dtypes=[pl.String, pl.Int64, pl.Int64, pl.Int64])
+        _ = pl.read_csv(
+            path, schema_overrides=[pl.String, pl.Int64, pl.Int64, pl.Int64]
+        )
 
     dtypes = {
         "category": pl.String,
@@ -874,7 +878,7 @@ def test_csv_globbing(io_files_path: Path) -> None:
         "sugars_g": pl.Int32,
     }
 
-    df = pl.read_csv(path, dtypes=dtypes)
+    df = pl.read_csv(path, schema_overrides=dtypes)
     assert df.dtypes == list(dtypes.values())
@@ -959,7 +963,7 @@ def test_escaped_null_values() -> None:
     df = pl.read_csv(
         f,
         null_values={"a": "None", "b": "n/a", "c": "NA"},
-        dtypes={"a": pl.String, "b": pl.Int64, "c": pl.Float64},
+        schema_overrides={"a": pl.String, "b": pl.Int64, "c": pl.Float64},
     )
     assert df[1, "a"] is None
     assert df[0, "b"] is None
@@ -1058,7 +1062,7 @@ def test_csv_overwrite_datetime_dtype(
     result = pl.read_csv(
         io.StringIO(data),
         try_parse_dates=try_parse_dates,
-        dtypes={"a": pl.Datetime(time_unit)},
+        schema_overrides={"a": pl.Datetime(time_unit)},
     )
     expected = pl.DataFrame(
         {
@@ -1205,7 +1209,7 @@ def test_csv_dtype_overwrite_bool() -> None:
     csv = "a, b\n" + ",false\n" + ",false\n" + ",false"
     df = pl.read_csv(
         csv.encode(),
-        dtypes={"a": pl.Boolean, "b": pl.Boolean},
+        schema_overrides={"a": pl.Boolean, "b": pl.Boolean},
     )
     assert df.dtypes == [pl.Boolean, pl.Boolean]
@@ -1401,7 +1405,9 @@ def test_csv_categorical_lifetime() -> None:
     """
     )
 
-    df = pl.read_csv(csv.encode(), dtypes={"a": pl.Categorical, "b": pl.Categorical})
+    df = pl.read_csv(
+        csv.encode(), schema_overrides={"a": pl.Categorical, "b": pl.Categorical}
+    )
     assert df.dtypes == [pl.Categorical, pl.Categorical]
     assert df.to_dict(as_series=False) == {
         "a": ["needs_escape", ' "needs escape foo', ' "needs escape foo'],
pl.Categorical, "b": pl.Categorical}) + df = pl.read_csv( + csv.encode(), schema_overrides={"a": pl.Categorical, "b": pl.Categorical} + ) assert df.dtypes == [pl.Categorical, pl.Categorical] assert df.to_dict(as_series=False) == { "a": ["needs_escape", ' "needs escape foo', ' "needs escape foo'], @@ -1416,9 +1422,9 @@ def test_csv_categorical_categorical_merge() -> None: f = io.BytesIO() pl.DataFrame({"x": ["A"] * N + ["B"] * N}).write_csv(f) f.seek(0) - assert pl.read_csv(f, dtypes={"x": pl.Categorical}, sample_size=10).unique( - maintain_order=True - )["x"].to_list() == ["A", "B"] + assert pl.read_csv( + f, schema_overrides={"x": pl.Categorical}, sample_size=10 + ).unique(maintain_order=True)["x"].to_list() == ["A", "B"] def test_batched_csv_reader(foods_file_path: Path) -> None: @@ -1522,7 +1528,7 @@ def test_csv_single_categorical_null() -> None: df = pl.read_csv( f, - dtypes={"y": pl.Categorical}, + schema_overrides={"y": pl.Categorical}, ) assert df.dtypes == [pl.String, pl.Categorical, pl.String] @@ -1536,7 +1542,9 @@ def test_csv_quoted_missing() -> None: '"1"|"Free text without a linebreak"|""|"789"\n' '"0"|"Free text with \ntwo \nlinebreaks"|"101112"|"131415"' ) - result = pl.read_csv(csv.encode(), separator="|", dtypes={"col3": pl.Int32}) + result = pl.read_csv( + csv.encode(), separator="|", schema_overrides={"col3": pl.Int32} + ) expected = pl.DataFrame( { "col1": [0, 1, 0], @@ -1580,7 +1588,7 @@ def test_csv_scan_categorical(tmp_path: Path) -> None: file_path = tmp_path / "test_csv_scan_categorical.csv" df.write_csv(file_path) - result = pl.scan_csv(file_path, dtypes={"x": pl.Categorical}).collect() + result = pl.scan_csv(file_path, schema_overrides={"x": pl.Categorical}).collect() assert result["x"].dtype == pl.Categorical @@ -1796,14 +1804,14 @@ def test_ignore_errors_casting_dtypes() -> None: assert pl.read_csv( source=io.StringIO(csv), - dtypes={"inventory": pl.Int8}, + schema_overrides={"inventory": pl.Int8}, ignore_errors=True, ).to_dict(as_series=False) == {"inventory": [10, None, None, 90]} with pytest.raises(pl.ComputeError): pl.read_csv( source=io.StringIO(csv), - dtypes={"inventory": pl.Int8}, + schema_overrides={"inventory": pl.Int8}, ignore_errors=False, ) @@ -1813,7 +1821,7 @@ def test_ignore_errors_date_parser() -> None: with pytest.raises(pl.ComputeError): pl.read_csv( source=io.StringIO(data_invalid_date), - dtypes={"date": pl.Date}, + schema_overrides={"date": pl.Date}, ignore_errors=False, ) @@ -1971,8 +1979,10 @@ def test_read_csv_invalid_dtypes() -> None: """ ) f = io.StringIO(csv) - with pytest.raises(TypeError, match="`dtypes` should be of type list or dict"): - pl.read_csv(f, dtypes={pl.Int64, pl.String}) # type: ignore[arg-type] + with pytest.raises( + TypeError, match="`schema_overrides` should be of type list or dict" + ): + pl.read_csv(f, schema_overrides={pl.Int64, pl.String}) # type: ignore[arg-type] @pytest.mark.parametrize("columns", [["b"], "b"]) @@ -2130,3 +2140,23 @@ def test_no_glob(tmpdir: Path) -> None: df.write_csv(str(p)) p = tmpdir / "*.csv" assert_frame_equal(pl.read_csv(str(p), glob=False), df) + + +def test_read_csv_dtypes_deprecated() -> None: + csv = textwrap.dedent( + """\ + a,b,c + 1,2,3 + 4,5,6 + """ + ) + f = io.StringIO(csv) + + with pytest.deprecated_call(): + df = pl.read_csv(f, dtypes=[pl.Int8, pl.Int8, pl.Int8]) # type: ignore[call-arg] + + expected = pl.DataFrame( + {"a": [1, 4], "b": [2, 5], "c": [3, 6]}, + schema={"a": pl.Int8, "b": pl.Int8, "c": pl.Int8}, + ) + assert_frame_equal(df, expected) diff --git 
diff --git a/py-polars/tests/unit/io/test_lazy_csv.py b/py-polars/tests/unit/io/test_lazy_csv.py
index bffaedc50b28..6af6eb67f3d5 100644
--- a/py-polars/tests/unit/io/test_lazy_csv.py
+++ b/py-polars/tests/unit/io/test_lazy_csv.py
@@ -80,7 +80,7 @@ def test_scan_csv_schema_overwrite_and_dtypes_overwrite(
     file_path = io_files_path / file_name
     df = pl.scan_csv(
         file_path,
-        dtypes={"calories_foo": pl.String, "fats_g_foo": pl.Float32},
+        schema_overrides={"calories_foo": pl.String, "fats_g_foo": pl.Float32},
         with_column_names=lambda names: [f"{a}_foo" for a in names],
     ).collect()
     assert df.dtypes == [pl.String, pl.String, pl.Float32, pl.Int64]
@@ -100,7 +100,7 @@ def test_scan_csv_schema_overwrite_and_small_dtypes_overwrite(
     file_path = io_files_path / file_name
     df = pl.scan_csv(
         file_path,
-        dtypes={"calories_foo": pl.String, "sugars_g_foo": dtype},
+        schema_overrides={"calories_foo": pl.String, "sugars_g_foo": dtype},
         with_column_names=lambda names: [f"{a}_foo" for a in names],
     ).collect()
     assert df.dtypes == [pl.String, pl.String, pl.Float64, dtype]
@@ -122,7 +122,7 @@ def test_scan_csv_schema_new_columns_dtypes(
     # assign 'new_columns', providing partial dtype overrides
     df1 = pl.scan_csv(
         file_path,
-        dtypes={"calories": pl.String, "sugars": dtype},
+        schema_overrides={"calories": pl.String, "sugars": dtype},
         new_columns=["category", "calories", "fats", "sugars"],
     ).collect()
     assert df1.dtypes == [pl.String, pl.String, pl.Float64, dtype]
@@ -131,7 +131,7 @@ def test_scan_csv_schema_new_columns_dtypes(
     # assign 'new_columns' with 'dtypes' list
     df2 = pl.scan_csv(
         file_path,
-        dtypes=[pl.String, pl.String, pl.Float64, dtype],
+        schema_overrides=[pl.String, pl.String, pl.Float64, dtype],
         new_columns=["category", "calories", "fats", "sugars"],
     ).collect()
     assert df1.rows() == df2.rows()
@@ -151,7 +151,7 @@ def test_scan_csv_schema_new_columns_dtypes(
     # partially rename columns / overwrite dtypes
     df4 = pl.scan_csv(
         file_path,
-        dtypes=[pl.String, pl.String],
+        schema_overrides=[pl.String, pl.String],
         new_columns=["category", "calories"],
     ).collect()
     assert df4.dtypes == [pl.String, pl.String, pl.Float64, pl.Int64]
@@ -161,7 +161,7 @@ def test_scan_csv_schema_new_columns_dtypes(
     with pytest.raises(pl.ShapeError):
         pl.scan_csv(
             file_path,
-            dtypes=[pl.String, pl.String],
+            schema_overrides=[pl.String, pl.String],
             new_columns=["category", "calories", "c3", "c4", "c5"],
         ).collect()
 
@@ -169,7 +169,7 @@ def test_scan_csv_schema_new_columns_dtypes(
     with pytest.raises(ValueError, match="mutually.exclusive"):
         pl.scan_csv(
             file_path,
-            dtypes=[pl.String, pl.String],
+            schema_overrides=[pl.String, pl.String],
             new_columns=["category", "calories", "fats", "sugars"],
             with_column_names=lambda cols: [col.capitalize() for col in cols],
         ).collect()
@@ -248,7 +248,7 @@ def test_scan_csv_schema_overwrite_not_projected_8483(foods_file_path: Path) ->
     df = (
         pl.scan_csv(
             foods_file_path,
-            dtypes={"calories": pl.String, "sugars_g": pl.Int8},
+            schema_overrides={"calories": pl.String, "sugars_g": pl.Int8},
         )
         .select(pl.len())
         .collect()
diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py
index cccedfdb49ba..a0175e814d41 100644
--- a/py-polars/tests/unit/io/test_spreadsheet.py
+++ b/py-polars/tests/unit/io/test_spreadsheet.py
@@ -386,7 +386,7 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N
     df2 = pl.read_excel(
         path_xlsx,
         sheet_name="test4",
-        read_options={"dtypes": {"cardinality": pl.UInt16}},
+        read_options={"schema_overrides": {"cardinality": pl.UInt16}},
     ).drop_nulls()
     assert df2.schema["cardinality"] == pl.UInt16
@@ -398,7 +398,7 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N
         sheet_name="test4",
         schema_overrides={"cardinality": pl.UInt16},
         read_options={
-            "dtypes": {
+            "schema_overrides": {
                 "rows_by_key": pl.Float32,
                 "iter_groups": pl.Float32,
             },
@@ -439,7 +439,7 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N
         path_xlsx,
         sheet_name="test4",
         schema_overrides={"cardinality": pl.UInt16},
-        read_options={"dtypes": {"cardinality": pl.Int32}},
+        read_options={"schema_overrides": {"cardinality": pl.Int32}},
     )
 
     # read multiple sheets in conjunction with 'schema_overrides'
diff --git a/py-polars/tests/unit/streaming/test_streaming_io.py b/py-polars/tests/unit/streaming/test_streaming_io.py
index 83fa75a4b4a8..e2533128cab9 100644
--- a/py-polars/tests/unit/streaming/test_streaming_io.py
+++ b/py-polars/tests/unit/streaming/test_streaming_io.py
@@ -41,7 +41,9 @@ def test_scan_csv_overwrite_small_dtypes(
     io_files_path: Path, dtype: pl.DataType
 ) -> None:
     file_path = io_files_path / "foods1.csv"
-    df = pl.scan_csv(file_path, dtypes={"sugars_g": dtype}).collect(streaming=True)
+    df = pl.scan_csv(file_path, schema_overrides={"sugars_g": dtype}).collect(
+        streaming=True
+    )
     assert df.dtypes == [pl.String, pl.Int64, pl.Float64, dtype]