From 8468b06144ac6581b18ff69f6cd911d999a6393f Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Sun, 24 Mar 2024 15:55:11 +0530 Subject: [PATCH 1/3] Dataframe includes non-dimension columns. This fixes an inconsistency between the actual columns in the dataframe and the df meta. --- xarray_sql/core.py | 1 + xarray_sql/df.py | 2 +- xarray_sql/df_test.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/xarray_sql/core.py b/xarray_sql/core.py index 82f4a3d..288eae0 100644 --- a/xarray_sql/core.py +++ b/xarray_sql/core.py @@ -7,6 +7,7 @@ Row = t.List[t.Any] +# deprecated def get_columns(ds: xr.Dataset) -> t.List[str]: return list(ds.dims.keys()) + list(ds.data_vars.keys()) diff --git a/xarray_sql/df.py b/xarray_sql/df.py index addb93e..1844328 100644 --- a/xarray_sql/df.py +++ b/xarray_sql/df.py @@ -101,7 +101,7 @@ def pivot(b: Block) -> pd.DataFrame: f'{"_".join(list(ds.data_vars.keys()))}' ) - columns = core.get_columns(ds) + columns = pivot(blocks[0]).columns # TODO(#18): Is it possible to pass the length (known now) here? meta = {c: ds[c].dtype for c in columns} diff --git a/xarray_sql/df_test.py b/xarray_sql/df_test.py index 8688ad0..9007145 100644 --- a/xarray_sql/df_test.py +++ b/xarray_sql/df_test.py @@ -3,11 +3,37 @@ import dask.dataframe as dd import numpy as np +import pandas as pd import xarray as xr from .df import explode, read_xarray, block_slices +def rand_wx(start: str, end: str) -> xr.Dataset: + np.random.seed(42) + lat = np.linspace(-90, 90, num=720) + lon = np.linspace(-180, 180, num=1440) + time = pd.date_range(start, end, freq='H') + level = np.array([1000, 500], dtype=np.int32) + reference_time = pd.Timestamp(start) + temperature = 15 + 8 * np.random.randn(720, 1440, len(time), len(level)) + precipitation = 10 * np.random.rand(720, 1440, len(time), len(level)) + return xr.Dataset( + data_vars=dict( + temperature=(['lat', 'lon', 'time', 'level'], temperature), + precipitation=(['lat', 'lon', 'time', 'level'], precipitation), + ), + coords=dict( + lat=lat, + lon=lon, + time=time, + level=level, + reference_time=reference_time, + ), + attrs=dict(description='Random weather.') + ) + + class DaskTestCase(unittest.TestCase): def setUp(self) -> None: @@ -18,6 +44,7 @@ def setUp(self) -> None: self.air_small = self.air.isel( time=slice(0, 12), lat=slice(0, 11), lon=slice(0, 10) ).chunk(self.chunks) + self.randwx = rand_wx('1995-01-13T00', '1995-01-13T01') class ExplodeTest(DaskTestCase): @@ -84,6 +111,13 @@ def test_chunk_perf(self): self.assertIsNotNone(df) self.assertEqual(len(df), np.prod(list(self.air.dims.values()))) + def test_column_metadata_preserved(self): + try: + _ = read_xarray(self.randwx, chunks=dict(time=24)).compute() + except ValueError as e: + if 'The columns in the computed data do not match the columns in the provided metadata' in str(e): + self.fail('Column metadata is incorrect.') + if __name__ == '__main__': unittest.main() From e5aff065fb5b8d8e25eacc7173117851b697c5b4 Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Sun, 24 Mar 2024 15:55:58 +0530 Subject: [PATCH 2/3] Reformat. --- xarray_sql/df_test.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/xarray_sql/df_test.py b/xarray_sql/df_test.py index 9007145..59953e6 100644 --- a/xarray_sql/df_test.py +++ b/xarray_sql/df_test.py @@ -19,18 +19,18 @@ def rand_wx(start: str, end: str) -> xr.Dataset: temperature = 15 + 8 * np.random.randn(720, 1440, len(time), len(level)) precipitation = 10 * np.random.rand(720, 1440, len(time), len(level)) return xr.Dataset( - data_vars=dict( - temperature=(['lat', 'lon', 'time', 'level'], temperature), - precipitation=(['lat', 'lon', 'time', 'level'], precipitation), - ), - coords=dict( - lat=lat, - lon=lon, - time=time, - level=level, - reference_time=reference_time, - ), - attrs=dict(description='Random weather.') + data_vars=dict( + temperature=(['lat', 'lon', 'time', 'level'], temperature), + precipitation=(['lat', 'lon', 'time', 'level'], precipitation), + ), + coords=dict( + lat=lat, + lon=lon, + time=time, + level=level, + reference_time=reference_time, + ), + attrs=dict(description='Random weather.'), ) @@ -115,7 +115,11 @@ def test_column_metadata_preserved(self): try: _ = read_xarray(self.randwx, chunks=dict(time=24)).compute() except ValueError as e: - if 'The columns in the computed data do not match the columns in the provided metadata' in str(e): + if ( + 'The columns in the computed data do not match the columns in the' + ' provided metadata' + in str(e) + ): self.fail('Column metadata is incorrect.') From c8f7cd505b086e68491b67b160ef5ed5c502c73f Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Sun, 24 Mar 2024 16:03:59 +0530 Subject: [PATCH 3/3] Reformatted with upgraded pyink. --- xarray_sql/df_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray_sql/df_test.py b/xarray_sql/df_test.py index 59953e6..aff1fc6 100644 --- a/xarray_sql/df_test.py +++ b/xarray_sql/df_test.py @@ -117,8 +117,7 @@ def test_column_metadata_preserved(self): except ValueError as e: if ( 'The columns in the computed data do not match the columns in the' - ' provided metadata' - in str(e) + ' provided metadata' in str(e) ): self.fail('Column metadata is incorrect.')