diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml index 8eef5a3b..4da18ff8 100644 --- a/.github/workflows/master.yml +++ b/.github/workflows/master.yml @@ -20,7 +20,7 @@ jobs: if: always() strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.x"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] fail-fast: false steps: @@ -61,7 +61,7 @@ jobs: run: | sudo apt-get update && sudo apt-get install libhdf5-dev libnetcdf-dev python -m pip install --upgrade pip - pip install xarray~=0.18.0 pandas~=1.4.0 + pip install xarray~=2023.1.0 pandas~=1.4.0 - name: Install package run: | pip install -e .[tests] diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index e323fa86..823eafcf 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -31,7 +31,6 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - sudo apt-get update && sudo apt-get install libhdf5-dev libnetcdf-dev python -m pip install --upgrade pip - name: Install package run: | @@ -59,9 +58,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - sudo apt-get update && sudo apt-get install libhdf5-dev libnetcdf-dev python -m pip install --upgrade pip - pip install xarray~=0.18.0 pandas~=1.4.0 + pip install xarray~=2023.1.0 pandas~=1.4.0 - name: Install package run: | pip install -e .[tests] diff --git a/.github/workflows/ruff.yaml b/.github/workflows/ruff.yaml index 5c26f5dd..c453abbf 100644 --- a/.github/workflows/ruff.yaml +++ b/.github/workflows/ruff.yaml @@ -5,11 +5,11 @@ jobs: dependency-review: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: '3.x' - name: Install ruff run: pip install ruff - name: Run ruff - run: ruff xbout + run: ruff check xbout diff --git a/pyproject.toml b/pyproject.toml index 39e926e8..d5eade87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ classifiers = [ ] requires-python = ">=3.8" dependencies = [ - "xarray>=0.18.0,<2022.9.0", + "xarray>=2023.01.0", "boutdata>=0.1.4", "dask[array]>=2.10.0", "gelidum>=0.5.3", @@ -53,6 +53,9 @@ calc = [ "xrft", "xhistogram", ] +cherab = [ + "cherab", +] docs = [ "sphinx >= 5.3", "sphinx-book-theme >= 0.4.0rc1", @@ -79,5 +82,5 @@ write_to = "xbout/_version.py" [tool.setuptools] packages = ["xbout"] -[tool.ruff] +[tool.ruff.lint] ignore = ["E501"] diff --git a/xbout/boutdataarray.py b/xbout/boutdataarray.py index 2a2c5ed6..3e54f402 100644 --- a/xbout/boutdataarray.py +++ b/xbout/boutdataarray.py @@ -1101,3 +1101,71 @@ def plot3d(self, ax=None, **kwargs): See plotfuncs.plot3d() """ return plotfuncs.plot3d(self.data, **kwargs) + + def with_cherab_grid(self): + """ + Returns a new DataArray with a 'cherab_grid' attribute. + + If called then the `cherab` package must be available. + """ + # Import here so Cherab is required only if this method is called + from .cherab import grid + + return grid.da_with_cherab_grid(self.data) + + def as_cherab_data(self): + """ + Returns a new cherab.TriangularData object. + + If a Cherab grid has not been calculated then it will be created. + It is more efficient to first compute a Cherab grid for a whole + DataSet (using `with_cherab_grid`) and then call this function + on individual DataArrays. 
+ """ + if "cherab_grid" not in self.data.attrs: + # Calculate the Cherab triangulation + da = self.with_cherab_grid() + else: + da = self.data + + return da.attrs["cherab_grid"].with_data(da) + + def as_cherab_emitter( + self, + parent=None, + cylinder_zmin=None, + cylinder_zmax=None, + cylinder_rmin=None, + cylinder_rmax=None, + step: float = 0.01, + ): + """ + Make a Cherab emitter (RadiationFunction), rotating a 2D mesh about the Z axis + + Cherab (https://www.cherab.info/) is a python library for forward + modelling diagnostics based on spectroscopic plasma emission. + It is based on the Raysect (http://www.raysect.org/) scientific + ray-tracing framework. + + Parameters + ---------- + parent : Cherab scene (default None) + The Cherab scene to attach the emitter to + step : float (default 0.01 meters) + Volume integration step length [m] + + Returns + ------- + + A cherab.tools.emitters.RadiationFunction + + """ + + return self.as_cherab_data().to_emitter( + parent=parent, + cylinder_zmin=cylinder_zmin, + cylinder_zmax=cylinder_zmax, + cylinder_rmin=cyliner_rmin, + cylinder_rmax=cyliner_rmax, + step=step, + ) diff --git a/xbout/boutdataset.py b/xbout/boutdataset.py index cc308fcb..94987b4b 100644 --- a/xbout/boutdataset.py +++ b/xbout/boutdataset.py @@ -597,12 +597,12 @@ def interpolate_to_cartesian( n_toroidal = ds.sizes[zdim] # Create Cartesian grid to interpolate to - Xmin = ds["X_cartesian"].min() - Xmax = ds["X_cartesian"].max() - Ymin = ds["Y_cartesian"].min() - Ymax = ds["Y_cartesian"].max() - Zmin = ds["Z_cartesian"].min() - Zmax = ds["Z_cartesian"].max() + Xmin = ds["X_cartesian"].min().data[()] + Xmax = ds["X_cartesian"].max().data[()] + Ymin = ds["Y_cartesian"].min().data[()] + Ymax = ds["Y_cartesian"].max().data[()] + Zmin = ds["Z_cartesian"].min().data[()] + Zmax = ds["Z_cartesian"].max().data[()] newX_1d = xr.DataArray(np.linspace(Xmin, Xmax, nX), dims="X") newX = newX_1d.expand_dims({"Y": nY, "Z": nZ}, axis=[1, 2]) newY_1d = xr.DataArray(np.linspace(Ymin, Ymax, nY), dims="Y") @@ -614,9 +614,7 @@ def interpolate_to_cartesian( # Define newzeta in range 0->2*pi newzeta = np.where(newzeta < 0.0, newzeta + 2.0 * np.pi, newzeta) - from scipy.interpolate import ( - RegularGridInterpolator, - ) + from scipy.interpolate import RegularGridInterpolator # Create Cylindrical coordinates for intermediate grid Rcyl_min = float_type(ds["R"].min()) @@ -664,10 +662,7 @@ def interp_single_time(da): ) print(" do 3d interpolation") - return interp( - (newR, newZ, newzeta), - method="linear", - ) + return interp((newR, newZ, newzeta), method="linear") for name, da in ds.data_vars.items(): print(f"\ninterpolating {name}") @@ -993,14 +988,7 @@ def to_restart( # Is this even possible without saving the guard cells? # Can they be recreated? restart_datasets, paths = _split_into_restarts( - self.data, - variables, - savepath, - nxpe, - nype, - tind, - prefix, - overwrite, + self.data, variables, savepath, nxpe, nype, tind, prefix, overwrite ) with ProgressBar(): @@ -1357,6 +1345,17 @@ def is_list(variable): return anim + def with_cherab_grid(self): + """ + Returns a new DataSet with a 'cherab_grid' attribute. + + If called then the `cherab` package must be available. 
+ """ + # Import here so Cherab is required only if this method is called + from .cherab import grid + + return grid.ds_with_cherab_grid(self.data) + def _find_major_vars(data): """ diff --git a/xbout/cherab/__init__.py b/xbout/cherab/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/xbout/cherab/grid.py b/xbout/cherab/grid.py new file mode 100644 index 00000000..b74f5213 --- /dev/null +++ b/xbout/cherab/grid.py @@ -0,0 +1,93 @@ +import numpy as np +import xarray as xr + +from .triangulate import Triangulate + + +def da_with_cherab_grid(da): + """ + Convert an BOUT++ DataArray to a format that Cherab can use: + - A 'cell_number' coordinate + - A 'cherab_grid` attribute + + The cell_number coordinate enables the DataArray to be sliced + before input to Cherab. + + Parameters + ---------- + ds : xarray.DataArray + + Returns + ------- + updated_da + """ + if "cherab_grid" in da.attrs: + # Already has required attribute + return da + + if da.attrs["geometry"] != "toroidal": + raise ValueError("xhermes.plotting.cherab: Geometry must be toroidal") + + if da.sizes["zeta"] != 1: + raise ValueError("xhermes.plotting.cherab: Zeta index must have size 1") + + nx = da.sizes["x"] + ny = da.sizes["theta"] + + # Cell corners + rm = np.stack( + ( + da.coords["Rxy_upper_right_corners"], + da.coords["Rxy_upper_left_corners"], + da.coords["Rxy_lower_right_corners"], + da.coords["Rxy_lower_left_corners"], + ), + axis=-1, + ) + zm = np.stack( + ( + da.coords["Zxy_upper_right_corners"], + da.coords["Zxy_upper_left_corners"], + da.coords["Zxy_lower_right_corners"], + da.coords["Zxy_lower_left_corners"], + ), + axis=-1, + ) + + grid = Triangulate(rm, zm) + + # Store the cell number as a coordinate. + # This allows slicing of arrays before passing to Cherab + + # Create a DataArray for the vertices and triangles + cell_number = xr.DataArray( + grid.cell_number, dims=["x", "theta"], name="cell_number" + ) + + result = da.assign_coords(cell_number=cell_number) + result.attrs.update(cherab_grid=grid) + return result + + +def ds_with_cherab_grid(ds): + """ + Create an xarray DataSet with a Cherab grid + + Parameters + ---------- + ds : xarray.Dataset + + Returns + ------- + updated_ds + """ + + # The same operation works on a DataSet + ds = da_with_cherab_grid(ds) + + # Add the Cherab grid as an attribute to all variables + grid = ds.attrs["cherab_grid"] + for var in ds.data_vars: + ds[var].attrs.update(cherab_grid=grid) + + return ds diff --git a/xbout/cherab/triangulate.py b/xbout/cherab/triangulate.py new file mode 100644 index 00000000..a26f73a1 --- /dev/null +++ b/xbout/cherab/triangulate.py @@ -0,0 +1,231 @@ +""" +Interface to Cherab + +This module performs triangulation of BOUT++ grids, making them +suitable for input to Cherab ray-tracing analysis. + +""" + +import numpy as np +import xarray as xr + + +class TriangularData: + """ + Represents a set of triangles with data constant on them. + Creates a Cherab Discrete2DMesh, and can then convert + that to a 3D (axisymmetric) emitting material. 
+ """ + + def __init__(self, vertices, triangles, data): + self.vertices = vertices + self.triangles = triangles + self.data = data + + from raysect.core.math.function.float import Discrete2DMesh + + self.mesh = Discrete2DMesh( + self.vertices, self.triangles, self.data, limit=False, default_value=0.0 + ) + + def to_emitter( + self, + parent=None, + cylinder_zmin=None, + cylinder_zmax=None, + cylinder_rmin=None, + cylinder_rmax=None, + step: float = 0.01, + ): + """ + Returns a 3D Cherab emitter, by rotating the 2D mesh about the Z axis + + step: Volume integration step length [m] + + """ + from raysect.core import translate + from raysect.primitive import Cylinder, Subtract + from raysect.optical.material import VolumeTransform + from cherab.core.math import AxisymmetricMapper + from cherab.tools.emitters import RadiationFunction + + if cylinder_zmin is None: + cylinder_zmin = np.amin(self.vertices[:, 1]) + if cylinder_zmax is None: + cylinder_zmax = np.amax(self.vertices[:, 1]) + if cylinder_rmin is None: + cylinder_rmin = np.amin(self.vertices[:, 0]) + if cylinder_rmax is None: + cylinder_rmax = np.amax(self.vertices[:, 0]) + + rad_function_3d = AxisymmetricMapper(self.mesh) + + shift = translate(0, 0, cylinder_zmin) + emitting_material = VolumeTransform( + RadiationFunction(rad_function_3d, step=step), shift.inverse() + ) + + # Create an annulus by removing the middle from the cylinder. + return Subtract( + Cylinder(cylinder_rmax, cylinder_zmax - cylinder_zmin), + Cylinder(cylinder_rmin, cylinder_zmax - cylinder_zmin), + transform=shift, + parent=parent, + material=emitting_material, + ) + + def plot_2d(self, ax=None, nr: int = 150, nz: int = 150): + """ + Make a 2D plot of the data + + nr, nz - Number of samples in R and Z + """ + import matplotlib.pyplot as plt + + if ax is None: + fig, ax = plt.subplots() + + Rmin, Zmin = np.amin(self.vertices, axis=0) + Rmax, Zmax = np.amax(self.vertices, axis=0) + + from cherab.core.math import sample2d + + r, z, emiss_sampled = sample2d(self.mesh, (Rmin, Rmax, nr), (Zmin, Zmax, nz)) + + image = ax.imshow( + emiss_sampled.T, origin="lower", extent=(r.min(), r.max(), z.min(), z.max()) + ) + fig.colorbar(image) + ax.set_xlabel("r") + ax.set_ylabel("z") + + return ax + + +class Triangulate: + """ + Represents a set of triangles for a 2D mesh in R-Z + + """ + + def __init__(self, rm, zm): + """ + rm and zm define quadrilateral cell corners in 2D (R, Z) + + rm : [nx, ny, 4] + zm : [nx, ny, 4] + """ + assert zm.shape == rm.shape + assert len(rm.shape) == 3 + nx, ny, n = rm.shape + assert n == 4 + + # Build a list of vertices and a list of triangles + vertices = [] + triangles = [] + + def vertex_index(R, Z): + """ + Return the vertex index at given (R,Z) location. 
+ Note: This is inefficient linear search + """ + # Check if there is already a vertex at this location + for i, v in enumerate(vertices): + vr, vz = v + d2 = (vr - R) ** 2 + (vz - Z) ** 2 + if d2 < 1e-10: + return i + # Not found so add a new vertex + vertices.append((R, Z)) + return len(vertices) - 1 + + for ix in range(nx): + for jy in range(ny): + # Adding cell (ix, jy) + # Get the vertex indices of the 4 corners + vertex_inds = [ + vertex_index(rm[ix, jy, n], zm[ix, jy, n]) for n in range(4) + ] + # Choose corners so triangles have the same sign + triangles.append(vertex_inds[0:3]) # Corners 1,2,3 + triangles.append(vertex_inds[:0:-1]) # Corners 4,3,2 + + self.vertices = np.array(vertices) + self.triangles = np.array(triangles) + self.cell_number = np.arange(nx * ny).reshape((nx, ny)) + + def __repr__(self): + return "" + + def plot_triangles(self, ax=None): + import matplotlib.pyplot as plt + + if ax is None: + fig, ax = plt.subplots() + + rs = self.vertices[self.triangles, 0] + zs = self.vertices[self.triangles, 1] + + # Close the triangles + rs = np.concatenate((rs, rs[:, 0:1]), axis=1) + zs = np.concatenate((zs, zs[:, 0:1]), axis=1) + + ax.plot(rs.T, zs.T, "k") + + return ax + + def with_data(self, da): + """ + Returns a new object containing vertices, triangles, and data + + Parameters + ---------- + + da : xarray.DataArray + Expected to have 'cherab_grid' attribute + and 'cell_number' coordinate. + Should only have 'x' and 'theta' dimensions. + + Returns + ------- + + A TriangularData object + """ + + if "cherab_grid" not in da.attrs: + raise ValueError("DataArray missing cherab_grid attribute") + + if "cell_number" not in da.coords: + raise ValueError("DataArray missing cell_number coordinate") + + da = da.squeeze() # Drop dimensions of size 1 + + # Check that extra dimensions (e.g time) have been dropped + # so that the data has the same dimensions as cell_number + if da.sizes != da.coords["cell_number"].sizes: + raise ValueError( + f"Data and cell_number coordinate have " + f"different sizes ({da.sizes} and " + f"{da.coords['cell_number'].sizes})" + ) + + if 2 * da.size == self.triangles.shape[0]: + # Data has not been sliced, so the size matches + # the number of triangles + + # Note: Two triangles per quad, so repeat data twice + return TriangularData( + self.vertices, self.triangles, da.data.flatten().repeat(2) + ) + + # Data size and number of triangles don't match. + # Use cell_number to work out which triangles to keep + + cells = da.coords["cell_number"].data.flatten() + triangles = np.concatenate( + (self.triangles[cells * 2, :], self.triangles[cells * 2 + 1, :]) + ) + + data = np.tile(da.data.flatten(), 2) + + return TriangularData(self.vertices, triangles, data) diff --git a/xbout/geometries.py b/xbout/geometries.py index fb50b432..dd1b6692 100644 --- a/xbout/geometries.py +++ b/xbout/geometries.py @@ -10,6 +10,7 @@ _set_attrs_on_all_vars, _set_as_coord, _1d_coord_from_spacing, + _maybe_rename_dimension, ) REGISTERED_GEOMETRIES = {} @@ -144,7 +145,7 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None): # 'dx' may not be consistent between different regions (e.g. core and PFR). # For some geometries xcoord may have already been created by # add_geometry_coords, in which case we do not need this. 
- nx = updated_ds.dims[xcoord] + nx = updated_ds.sizes[xcoord] # can't use commented out version, uncommented one works around xarray bug # removing attrs @@ -181,7 +182,7 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None): if zcoord in updated_ds.dims and zcoord not in updated_ds.coords: # Generates a coordinate whose value is 0 on the first grid point, not dz/2, to # match how BOUT++ generates fields from input file expressions. - nz = updated_ds.dims[zcoord] + nz = updated_ds.sizes[zcoord] # In BOUT++ v5, dz is either a Field2D or Field3D. # We can use it as a 1D coordinate if it's a Field3D, _or_ if nz == 1 @@ -213,7 +214,7 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None): dz = updated_ds["dz"] z0 = 2 * np.pi * updated_ds.metadata["ZMIN"] - z1 = z0 + nz * dz + z1 = z0 + nz * dz.data[()] if not np.all( np.isclose( z1, @@ -392,12 +393,12 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): ], ) - if "t" in ds.dims: + if coordinates["t"] != "t": # Rename 't' if user requested it - ds = ds.rename(t=coordinates["t"]) + ds = _maybe_rename_dimension(ds, "t", coordinates["t"]) # Change names of dimensions to Orthogonal Toroidal ones - ds = ds.rename(y=coordinates["y"]) + ds = _maybe_rename_dimension(ds, "y", coordinates["y"]) # TODO automatically make this coordinate 1D in simplified cases? ds[coordinates["x"]] = ds["psixy"] @@ -413,7 +414,7 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): # If full data (not just grid file) then toroidal dim will be present if "z" in ds.dims: - ds = ds.rename(z=coordinates["z"]) + ds = _maybe_rename_dimension(ds, "z", coordinates["z"]) # Record which dimension 'z' was renamed to. ds.metadata["bout_zdim"] = coordinates["z"] @@ -505,7 +506,7 @@ def add_s_alpha_geometry_coords(ds, *, coordinates=None, grid=None): ds["r"] = ds["hthe"].isel({ycoord: 0}).squeeze(drop=True) ds["r"].attrs["units"] = "m" ds = ds.set_coords("r") - ds = ds.rename(x="r") + ds = ds.swap_dims(x="r") ds.metadata["bout_xdim"] = "r" if hthe_from_grid: diff --git a/xbout/load.py b/xbout/load.py index 33a4b8f7..3598653f 100644 --- a/xbout/load.py +++ b/xbout/load.py @@ -6,7 +6,6 @@ from boutdata.data import BoutOptionsFile import xarray as xr -from numpy import unique from natsort import natsorted @@ -19,22 +18,6 @@ _is_dir, ) - -_BOUT_PER_PROC_VARIABLES = [ - "wall_time", - "wtime", - "wtime_rhs", - "wtime_invert", - "wtime_comms", - "wtime_io", - "wtime_per_rhs", - "wtime_per_rhs_e", - "wtime_per_rhs_i", - "PE_XIND", - "PE_YIND", - "MYPE", -] -_BOUT_TIME_DEPENDENT_META_VARS = ["iteration", "hist_hi", "tt"] _BOUT_GEOMETRY_VARS = [ "ixseps1", "ixseps2", @@ -69,15 +52,12 @@ ) -# TODO somehow check that we have access to the latest version of auto_combine - - def open_boutdataset( datapath="./BOUT.dmp.*.nc", inputfilepath=None, geometry=None, gridfilepath=None, - grid_mismatch="raise", #: Union[Literal["raise"], Literal["warn"], Literal["ignore"]] + grid_mismatch="raise", chunks=None, keep_xboundaries=True, keep_yboundaries=False, @@ -87,48 +67,55 @@ def open_boutdataset( is_mms_dump=False, **kwargs, ): - """ - Load a dataset from a set of BOUT output files, including the input options - file. Can also load from a grid file or from restart files. - - Note that when reloading a Dataset that was saved by xBOUT, the state of the saved - Dataset is restored, and the values of ``keep_xboundaries``, ``keep_yboundaries``, and - ``run_name`` are ignored. 
``geometry`` is treated specially, and can be passed when + """Load a dataset from a set of BOUT output files, including the + input options file. Can also load from a grid file or from restart + files. + + Note that when reloading a Dataset that was saved by xBOUT, the + state of the saved Dataset is restored, and the values of + ``keep_xboundaries``, ``keep_yboundaries``, and ``run_name`` are + ignored. ``geometry`` is treated specially, and can be passed when reloading a Dataset (along with ``gridfilepath`` if needed). Troubleshooting --------------- - Variable conflicts: sometimes, for example when loading data from multiple restarts, - some variables may have conflicts (e.g. a source term was changed between some of - the restarts, but the source term is saved as time-independent, without a - t-dimension). In this case one workaround is to pass a list of variable names to the - keyword argument ``drop_vars`` to ignore the variables with conflicts, e.g. if ``"S1"`` - and ``"S2"`` have conflicts:: + Variable conflicts: sometimes, for example when loading data from + multiple restarts, some variables may have conflicts (e.g. a + source term was changed between some of the restarts, but the + source term is saved as time-independent, without a + t-dimension). In this case one workaround is to pass a list of + variable names to the keyword argument ``drop_vars`` to ignore the + variables with conflicts, e.g. if ``"S1"`` and ``"S2"`` have + conflicts:: ds = open_boutdataset("data*/boutdata.nc", drop_variables=["S1", "S2"]) will open a Dataset which is missing ``"S1"`` and ``"S2"`` - (``drop_variables`` is an argument of `xarray.open_dataset` that is passed down - through ``kwargs``.) + (``drop_variables`` is an argument of `xarray.open_dataset` that + is passed down through ``kwargs``.) Parameters ---------- - datapath : str or (list or tuple of xr.Dataset), optional - Path to the data to open. Can point to either a set of one or more dump - files, or a single grid file. - To specify multiple dump files you must enter the path to them as a - single glob, e.g. './BOUT.dmp.*.nc', or for multiple consecutive runs - in different directories (in order) then './run*/BOUT.dmp.*.nc'. + datapath : str or (list or tuple of xr.Dataset), optional Path to + the data to open. Can point to either a set of one or more + dump files, or a single grid file. + + To specify multiple dump files you must enter the path to them + as a single glob, e.g. './BOUT.dmp.*.nc', or for multiple + consecutive runs in different directories (in order) then + './run*/BOUT.dmp.*.nc'. + + If a list or tuple of xr.Dataset is passed, they will be + combined with xr.combine_nested() instead of loading data from + disk (intended for unit testing). - If a list or tuple of xr.Dataset is passed, they will be combined with - xr.combine_nested() instead of loading data from disk (intended for unit - testing). chunks : dict, optional inputfilepath : str, optional geometry : str, optional - The geometry type of the grid data. This will specify what type of - coordinates to add to the dataset, e.g. 'toroidal' or 'cylindrical'. + The geometry type of the grid data. This will specify what + type of coordinates to add to the dataset, e.g. 'toroidal' or + 'cylindrical'. If not specified then will attempt to read it from the file attrs. If still not found then a warning will be thrown, which can be @@ -138,43 +125,55 @@ def open_boutdataset( `register_geometry` decorator. 
You are encouraged to do this for your own BOUT++ physics module, to apply relevant normalisations. + gridfilepath : str, optional - The path to a grid file, containing any variables needed to apply the geometry - specified by the 'geometry' option, which are not contained in the dump files. - This may either be the path of the grid file itself, or the directory - relative to which the grid from the settings file can be found. + The path to a grid file, containing any variables needed to + apply the geometry specified by the 'geometry' option, which + are not contained in the dump files. This may either be the + path of the grid file itself, or the directory relative to + which the grid from the settings file can be found. + grid_mismatch : str, optional - How to handle if the grid is not the grid that has been used for the - simulation. Can be "raise" to raise a RuntimeError, "warn" to raise a - warning, or ignore to ignore the mismatch silently. + How to handle if the grid is not the grid that has been used + for the simulation. Can be "raise" to raise a RuntimeError, + "warn" to raise a warning, or ignore to ignore the mismatch + silently. + keep_xboundaries : bool, optional - If true, keep x-direction boundary cells (the cells past the physical - edges of the grid, where boundary conditions are set); increases the - size of the x dimension in the returned data-set. If false, trim these - cells. + If true, keep x-direction boundary cells (the cells past the + physical edges of the grid, where boundary conditions are + set); increases the size of the x dimension in the returned + data-set. If false, trim these cells. + keep_yboundaries : bool, optional - If true, keep y-direction boundary cells (the cells past the physical - edges of the grid, where boundary conditions are set); increases the - size of the y dimension in the returned data-set. If false, trim these - cells. + If true, keep y-direction boundary cells (the cells past the + physical edges of the grid, where boundary conditions are + set); increases the size of the y dimension in the returned + data-set. If false, trim these cells. + run_name : str, optional - Name to give to the whole dataset, e.g. 'JET_ELM_high_resolution'. - Useful if you are going to open multiple simulations and compare the - results. + Name to give to the whole dataset, + e.g. 'JET_ELM_high_resolution'. Useful if you are going to + open multiple simulations and compare the results. + info : bool or "terse", optional is_restart : bool, optional - Restart files require some special handling (e.g. working around variables that - are not present in restart files). By default, this special handling is enabled - if the files do not have a time dimension and ``restart`` is present in the file - name in ``datapath``. This option can be set to True or False to explicitly enable - or disable the restart file handling. + Restart files require some special handling (e.g. working + around variables that are not present in restart files). By + default, this special handling is enabled if the files do not + have a time dimension and ``restart`` is present in the file + name in ``datapath``. This option can be set to True or False + to explicitly enable or disable the restart file handling. + kwargs : optional - Keyword arguments are passed down to `xarray.open_mfdataset`, which in - turn passes extra kwargs down to `xarray.open_dataset`. + Keyword arguments are passed down to `xarray.open_mfdataset`, + which in turn passes extra kwargs down to + `xarray.open_dataset`. 
Returns ------- ds : xarray.Dataset + """ if chunks is None: @@ -190,8 +189,8 @@ def open_boutdataset( if "reload" in input_type: if input_type == "reload": if isinstance(datapath, Path): - # xr.open_mfdataset only accepts glob patterns as strings, not Path - # objects + # xr.open_mfdataset only accepts glob patterns as + # strings, not Path objects datapath = str(datapath) ds = xr.open_mfdataset( datapath, @@ -297,15 +296,6 @@ def attrs_remove_section(obj, section): else: raise ValueError(f"internal error: unexpected input_type={input_type}") - if not is_restart: - for var in _BOUT_TIME_DEPENDENT_META_VARS: - if var in ds: - # Assume different processors in x & y have same iteration etc. - latest_top_left = {dim: 0 for dim in ds[var].dims} - if "t" in ds[var].dims: - latest_top_left["t"] = -1 - ds[var] = ds[var].isel(latest_top_left).squeeze(drop=True) - ds, metadata = _separate_metadata(ds) # Store as ints because netCDF doesn't support bools, so we can't save # bool attributes @@ -333,7 +323,8 @@ def attrs_remove_section(obj, section): gridfilepath += "/" + ds.options["grid"] else: warn( - "gridfilepath set to a directory, but no grid used in simulation. Continuing without grid." + "gridfilepath set to a directory, but no grid used " + "in simulation. Continuing without grid." ) if gridfilepath is not None: grid = _open_grid( @@ -383,15 +374,15 @@ def attrs_remove_section(obj, section): if run_name: ds.name = run_name - # Set some default settings that are only used in post-processing by xBOUT, not by - # BOUT++ + # Set some default settings that are only used in post-processing + # by xBOUT, not by BOUT++ ds.bout.fine_interpolation_factor = 8 if ("dump" in input_type or "restart" in input_type) and ds.metadata[ "BOUT_VERSION" ] < 4.0: - # Add workarounds for missing information or different conventions in data saved - # by BOUT++ v3.x. + # Add workarounds for missing information or different + # conventions in data saved by BOUT++ v3.x. for v in ds: if ds.metadata["bout_zdim"] in ds[v].dims: # All fields saved on aligned grid for BOUT-3 @@ -406,21 +397,22 @@ def attrs_remove_section(obj, section): ds.metadata["bout_zdim"], ) ): - # zShift, etc. did not support staggered grids in BOUT++ v3 anyway, so - # just treat all variables as if they were at CELL_CENTRE + # zShift, etc. did not support staggered grids in + # BOUT++ v3 anyway, so just treat all variables as if + # they were at CELL_CENTRE ds[v].attrs["cell_location"] = "CELL_CENTRE" added_location = True if added_location: warn( - "Detected data from BOUT++ v3.x. Treating all variables as being " - "at `CELL_CENTRE`. Should be similar to what BOUT++ v3.x did, but " - "if your code uses staggered grids, this may produce unexpected " - "effects in some places." + "Detected data from BOUT++ v3.x. Treating all variables" + " as being at `CELL_CENTRE`. Should be similar to what" + " BOUT++ v3.x did, but if your code uses staggered grids," + " this may produce unexpected effects in some places." ) if "nz" not in ds.metadata: - # `nz` used to be stored as `MZ` and `MZ` used to include an extra buffer - # point that was not used for data. + # `nz` used to be stored as `MZ` and `MZ` used to include + # an extra buffer point that was not used for data. 
ds.metadata["nz"] = ds.metadata["MZ"] - 1 if info == "terse": @@ -458,9 +450,7 @@ def collect( info=True, prefix="BOUT.dmp", ): - """ - - Extract the data pertaining to a specified variable in a BOUT++ data set + """Extract the data pertaining to a specified variable in a BOUT++ data set Parameters @@ -486,11 +476,14 @@ def collect( Notes ---------- - strict : This option found in boutdata.collect() is not present in this function - it is assumed that the varname given is correct, if variable does not exist - the function will fail - tind_auto : This option is not required when using _auto_open_mfboutdataset as an - automatic failure if datasets are different lengths is included + strict : This option found in boutdata.collect() is not present in + this function it is assumed that the varname given is + correct, if variable does not exist the function will + fail + + tind_auto : This option is not required when using + _auto_open_mfboutdataset as an automatic failure if + datasets are different lengths is included Returns ---------- @@ -622,11 +615,6 @@ def _auto_open_mfboutdataset( if chunks is None: chunks = {} - if is_restart: - data_vars = "minimal" - else: - data_vars = _BOUT_TIME_DEPENDENT_META_VARS - if _is_path(datapath): filepaths, filetype = _expand_filepaths(datapath) @@ -635,16 +623,20 @@ def _auto_open_mfboutdataset( filepaths[0], info, keep_yboundaries ) if is_squashed_doublenull: - # Need to remove y-boundaries after loading: (i) in case we are loading a - # squashed data-set, in which case we cannot easily remove the upper - # boundary cells in _trim(); (ii) because using the remove_yboundaries() - # method for non-squashed data-sets is simpler than replicating that logic - # in _trim(). + # Need to remove y-boundaries after loading: (i) in case + # we are loading a squashed data-set, in which case we + # cannot easily remove the upper boundary cells in + # _trim(); (ii) because using the remove_yboundaries() + # method for non-squashed data-sets is simpler than + # replicating that logic in _trim(). remove_yboundaries = not keep_yboundaries keep_yboundaries = True else: remove_yboundaries = False + # Create a partial application of _trim + # Calls to _preprocess will call _trim to trim guard / boundary cells + # from datasets before merging. _preprocess = partial( _trim, guards={"x": mxg, "y": myg}, @@ -656,40 +648,28 @@ def _auto_open_mfboutdataset( paths_grid, concat_dims = _arrange_for_concatenation(filepaths, nxpe, nype) - try: - ds = xr.open_mfdataset( - paths_grid, - concat_dim=concat_dims, - combine="nested", - data_vars=data_vars, - preprocess=_preprocess, - engine=filetype, - chunks=chunks, - join="exact", - **kwargs, - ) - except ValueError as e: - message_to_catch = ( - "some variables in data_vars are not data variables on the first " - "dataset:" - ) - if str(e)[: len(message_to_catch)] == message_to_catch: - # Open concatenating any variables that are different in - # different files as a work around to support opening older - # data. - ds = xr.open_mfdataset( - paths_grid, - concat_dim=concat_dims, - combine="nested", - data_vars="different", - preprocess=_preprocess, - engine=filetype, - chunks=chunks, - join="exact", - **kwargs, - ) - else: - raise + ds = xr.open_mfdataset( + paths_grid, + concat_dim=concat_dims, + combine="nested", + preprocess=_preprocess, + engine=filetype, + chunks=chunks, + # Only data variables in which the dimension already + # appears are concatenated. 
+ data_vars="minimal", + # Only coordinates in which the dimension already appears + # are concatenated. + coords="minimal", + # Duplicate data taken from first dataset + compat="override", + # Duplicate attributes taken from first dataset + combine_attrs="override", + # Don't align. Raise ValueError when indexes to be aligned + # are not equal + join="exact", + **kwargs, + ) else: # datapath was nested list of Datasets @@ -707,9 +687,9 @@ def _auto_open_mfboutdataset( ) if is_squashed_doublenull: - # Need to remove y-boundaries after loading when loading a squashed - # data-set, in which case we cannot easily remove the upper boundary cells - # in _trim(). + # Need to remove y-boundaries after loading when loading a + # squashed data-set, in which case we cannot easily remove + # the upper boundary cells in _trim(). remove_yboundaries = not keep_yboundaries keep_yboundaries = True else: @@ -731,9 +711,17 @@ def _auto_open_mfboutdataset( ds = xr.combine_nested( ds_grid, concat_dim=concat_dims, - data_vars=data_vars, join="exact", - combine_attrs="no_conflicts", + # Only data variables in which the dimension already + # appears are concatenated. + data_vars="minimal", + # Only coordinates in which the dimension already appears + # are concatenated. + coords="minimal", + # Duplicate data taken from first dataset + compat="override", + # Duplicate attributes taken from first dataset + combine_attrs="override", ) if not is_restart: @@ -779,7 +767,8 @@ def _expand_wildcards(path): # Find path relative to parent search_pattern = str(path.relative_to(base_dir)) - # Search this relative path from the parent directory for all files matching user input + # Search this relative path from the parent directory + # for all files matching user input filepaths = list(base_dir.glob(search_pattern)) # Sort by numbers in filepath before returning @@ -807,7 +796,7 @@ def get_nonnegative_scalar(ds, key, default=1, info=True): print(f"{key} not found, setting to {default}") if default < 0: raise ValueError( - f"Default for {key} is {val}, but negative values are not valid" + f"Default for {key} is {val}," f" but negative values are not valid" ) return default @@ -816,14 +805,15 @@ def get_nonnegative_scalar(ds, key, default=1, info=True): mxg = get_nonnegative_scalar(ds, "MXG", default=2, info=info) myg = get_nonnegative_scalar(ds, "MYG", default=0, info=info) mxsub = get_nonnegative_scalar( - ds, "MXSUB", default=ds.dims["x"] - 2 * mxg, info=info + ds, "MXSUB", default=ds.sizes["x"] - 2 * mxg, info=info ) mysub = get_nonnegative_scalar( - ds, "MYSUB", default=ds.dims["y"] - 2 * myg, info=info + ds, "MYSUB", default=ds.sizes["y"] - 2 * myg, info=info ) - # Check whether this is a single file squashed from the multiple output files of a - # parallel run (i.e. NXPE*NYPE > 1 even though there is only a single file to read). + # Check whether this is a single file squashed from the multiple + # output files of a parallel run (i.e. NXPE*NYPE > 1 even though + # there is only a single file to read). 
if "nx" in ds: nx = ds["nx"].values else: @@ -834,12 +824,12 @@ def get_nonnegative_scalar(ds, key, default=1, info=True): else: # Workaround for older data files ny = ds["MYSUB"].values * ds["NYPE"].values - nx_file = ds.dims["x"] - ny_file = ds.dims["y"] + nx_file = ds.sizes["x"] + ny_file = ds.sizes["y"] is_squashed_doublenull = False if nxpe > 1 or nype > 1: - # if nxpe = nype = 1, was only one process anyway, so no need to check for - # squashing + # if nxpe = nype = 1, was only one process anyway, so no need + # to check for squashing if nx_file == nx or nx_file == nx - 2 * mxg: has_xboundaries = nx_file == nx if not has_xboundaries: @@ -854,7 +844,8 @@ def get_nonnegative_scalar(ds, key, default=1, info=True): else: upper_target_cells = 0 if ny_file == ny or ny_file == ny + 2 * myg + 2 * upper_target_cells: - # This file contains all the points, possibly including guard cells + # This file contains all the points, possibly + # including guard cells has_yboundaries = not (ny_file == ny) if not has_yboundaries: @@ -870,8 +861,8 @@ def get_nonnegative_scalar(ds, key, default=1, info=True): # squashed with upper target points. is_squashed_doublenull = False elif ny_file == ny + 2 * myg: - # Older squashed file from double-null grid but containing only lower - # target boundary cells. + # Older squashed file from double-null grid but + # containing only lower target boundary cells. if keep_yboundaries: raise ValueError( "Cannot keep y-boundary points: squashed file is missing upper " @@ -883,8 +874,9 @@ def get_nonnegative_scalar(ds, key, default=1, info=True): nxpe = 1 nype = 1 - # For this case, do not need the special handling enabled by - # is_squashed_doublenull=True, as keeping y-boundaries is not allowed + # For this case, do not need the special handling + # enabled by is_squashed_doublenull=True, as keeping + # y-boundaries is not allowed is_squashed_doublenull = False # Avoid trying to open this file twice @@ -933,7 +925,7 @@ def getrunid(fp): "`BOUT.dmp.0.nc`." ) raise ValueError( - f"A parallel simulation was loaded, but only {len(filepathts)} " + f"A parallel simulation was loaded, but only {len(filepaths)} " "files were loaded. Please ensure to pass in all files " "by specifing e.g. `BOUT.dmp.*.nc`" ) @@ -945,8 +937,8 @@ def getrunid(fp): "load each directory separately and concatenate them " "along the time dimension with xarray.concat()." ) - # Create list of lists of filepaths, so that xarray knows how they should - # be concatenated by xarray.open_mfdataset() + # Create list of lists of filepaths, so that xarray knows how + # they should be concatenated by xarray.open_mfdataset() paths = iter(filepaths) paths_grid = [ [[next(paths) for x in range(nxpe)] for y in range(nype)] @@ -968,13 +960,15 @@ def getrunid(fp): if tmp["PE_XIND"] != 0 or tmp["PE_YIND"] != 0: # The first file is missing. warn( - f"Ignoring {len(paths)} files as the first seems to be missing: {paths}" + f"Ignoring {len(paths)} files as the first" + f" seems to be missing: {paths}" ) continue assert tmp["NXPE"] == nxpe assert tmp["NYPE"] == nype raise ValueError( - f"Something is wrong. We expected {nprocs} files but found {len(paths)} files." + f"Something is wrong. We expected {nprocs} files" + f" but found {len(paths)} files." ) paths = iter(paths) @@ -994,11 +988,13 @@ def getrunid(fp): def _trim(ds, *, guards, keep_boundaries, nxpe, nype, is_restart): - """ - Trims all guard (and optionally boundary) cells off a single dataset read from a - single BOUT dump file, to prepare for concatenation. 
- Also drops some variables that store timing information, which are different for each - process and so cannot be concatenated. + """Trims all guard (and optionally boundary) cells off a single + dataset read from a single BOUT dump file, to prepare for + concatenation. + + Variables that store timing information, which are different for + each process, are not trimmed but are taken from the first + processor during concatenation. Parameters ---------- @@ -1013,6 +1009,7 @@ def _trim(ds, *, guards, keep_boundaries, nxpe, nype, is_restart): Number of processors in y direction is_restart : bool Is data being loaded from restart files? + """ if any(keep_boundaries.values()): @@ -1037,15 +1034,14 @@ def _trim(ds, *, guards, keep_boundaries, nxpe, nype, is_restart): ): trimmed_ds = trimmed_ds.drop_vars(name) - to_drop = _BOUT_PER_PROC_VARIABLES - - return trimmed_ds.drop_vars(to_drop, errors="ignore") + return trimmed_ds def _infer_contains_boundaries(ds, nxpe, nype): - """ - Uses the processor indices and BOUT++'s topology indices to work out whether this - dataset contains boundary cells, and on which side. + """Uses the processor indices and BOUT++'s topology indices to + work out whether this dataset contains boundary cells, and on + which side. + """ if nxpe * nype == 1: @@ -1057,8 +1053,8 @@ def _infer_contains_boundaries(ds, nxpe, nype): yproc = int(ds["PE_YIND"]) except KeyError: # output file from BOUT++ earlier than 4.3 - # Use knowledge that BOUT names its output files as /folder/prefix.num.nc, with a - # numbering scheme + # Use knowledge that BOUT names its output files as + # /folder/prefix.num.nc, with a numbering scheme # num = nxpe*i + j, where i={0, ..., nype}, j={0, ..., nxpe} filename = ds.encoding["source"] *prefix, filenum, extension = Path(filename).suffixes @@ -1115,9 +1111,10 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2, **kw acceptable_dims = ["x", "y", "z"] - # Passing 'chunks' with dimensions that are not present in the dataset causes an - # error. A gridfile will be missing 't' and may be missing 'z' dimensions that dump - # files have, so we must remove them from 'chunks'. + # Passing 'chunks' with dimensions that are not present in the + # dataset causes an error. A gridfile will be missing 't' and may + # be missing 'z' dimensions that dump files have, so we must + # remove them from 'chunks'. grid_chunks = copy(chunks) unrecognised_chunk_dims = list(set(grid_chunks.keys()) - set(acceptable_dims)) for dim in unrecognised_chunk_dims: @@ -1131,8 +1128,6 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2, **kw else: grid = datapath - # TODO find out what 'yup_xsplit' etc are in the doublenull storm file John gave me - # For now drop any variables with extra dimensions unrecognised_dims = list(set(grid.dims) - set(acceptable_dims)) if len(unrecognised_dims) > 0: # Weird string formatting is a workaround to deal with possible bug in @@ -1144,15 +1139,15 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2, **kw grid = grid.drop_dims(unrecognised_dims) if keep_xboundaries: - # Set MXG so that it is picked up in metadata - needed for applying geometry, - # etc. + # Set MXG so that it is picked up in metadata - needed for + # applying geometry, etc. grid["MXG"] = mxg else: xboundaries = mxg if xboundaries > 0: grid = grid.isel(x=slice(xboundaries, -xboundaries, None)) - # Set MXG so that it is picked up in metadata - needed for applying geometry, - # etc. 
+ # Set MXG so that it is picked up in metadata - needed for + # applying geometry, etc. grid["MXG"] = 0 try: yboundaries = int(grid["y_boundary_guards"]) @@ -1161,16 +1156,16 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2, **kw # never had y-boundary cells yboundaries = 0 if keep_yboundaries: - # Set MYG so that it is picked up in metadata - needed for applying geometry, - # etc. + # Set MYG so that it is picked up in metadata + # - needed for applying geometry, etc. grid["MYG"] = yboundaries else: if yboundaries > 0: # Remove y-boundary cells from first divertor target grid = grid.isel(y=slice(yboundaries, -yboundaries, None)) if grid["jyseps1_2"] > grid["jyseps2_1"]: - # There is a second divertor target, remove y-boundary cells - # there too + # There is a second divertor target, remove y-boundary + # cells there too nin = int(grid["ny_inner"]) grid_lower = grid.isel(y=slice(None, nin, None)) grid_upper = grid.isel(y=slice(nin + 2 * yboundaries, None, None)) @@ -1181,8 +1176,8 @@ def _open_grid(datapath, chunks, keep_xboundaries, keep_yboundaries, mxg=2, **kw compat="identical", join="exact", ) - # Set MYG so that it is picked up in metadata - needed for applying geometry, - # etc. + # Set MYG so that it is picked up in metadata + # - needed for applying geometry, etc. grid["MYG"] = 0 if "z" in grid_chunks and "z" not in grid.dims: diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index db0a1e67..0f9dc51f 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -584,10 +584,7 @@ def plot3d( fill_value=datamin - 2.0 * (datamax - datamin), ) print("do 3d interpolation") - grid = interp( - (newzeta, newR, newZ), - method="linear", - ) + grid = interp((newzeta, newR, newZ), method="linear") print("done interpolating") if style == "isosurface": @@ -932,10 +929,10 @@ def plot2d_polygon( if vmax is None: vmax = np.nanmax(da.max().values) - if colorbar_label == None: + if colorbar_label is None: if "short_name" in da.attrs: colorbar_label = da.attrs["short_name"] - elif da.name != None: + elif da.name is not None: colorbar_label = da.name else: colorbar_label = "" diff --git a/xbout/region.py b/xbout/region.py index d30d4047..4db156a8 100644 --- a/xbout/region.py +++ b/xbout/region.py @@ -124,8 +124,8 @@ def __init__( ref_yind = ylower_ind dx = ds["dx"].isel({self.ycoord: ref_yind}) dx_cumsum = dx.cumsum() - self.xinner = dx_cumsum[xinner_ind] - dx[xinner_ind] - self.xouter = dx_cumsum[xouter_ind - 1] + dx[xouter_ind - 1] + self.xinner = (dx_cumsum[xinner_ind] - dx[xinner_ind]).values + self.xouter = (dx_cumsum[xouter_ind - 1] + dx[xouter_ind - 1]).values # dy is constant in the x-direction, so convert to a 1d array # Define ref_xind so that we avoid using values from the corner cells, which @@ -136,8 +136,8 @@ def __init__( ref_xind = xinner_ind dy = ds["dy"].isel(**{self.xcoord: ref_xind}) dy_cumsum = dy.cumsum() - self.ylower = dy_cumsum[ylower_ind] - dy[ylower_ind] - self.yupper = dy_cumsum[yupper_ind - 1] + self.ylower = (dy_cumsum[ylower_ind] - dy[ylower_ind]).values + self.yupper = (dy_cumsum[yupper_ind - 1]).values def __repr__(self): result = "\n" @@ -1348,7 +1348,9 @@ def _concat_inner_guards(da, da_global, mxg): # https://github.com/pydata/xarray/issues/4393 # da_inner = da_inner.assign_coords(**{xcoord: new_xcoord, ycoord: new_ycoord}) da_inner[xcoord].data[...] = new_xcoord.data + da_inner = da_inner.reset_index(xcoord).set_xindex(xcoord) da_inner[ycoord].data[...] 
= new_ycoord.data + da_inner = da_inner.reset_index(ycoord).set_xindex(ycoord) save_regions = da.bout._regions da = xr.concat((da_inner, da), xcoord, join="exact") @@ -1459,7 +1461,9 @@ def _concat_outer_guards(da, da_global, mxg): # https://github.com/pydata/xarray/issues/4393 # da_outer = da_outer.assign_coords(**{xcoord: new_xcoord, ycoord: new_ycoord}) da_outer[xcoord].data[...] = new_xcoord.data + da_outer = da_outer.reset_index(xcoord).set_xindex(xcoord) da_outer[ycoord].data[...] = new_ycoord.data + da_outer = da_outer.reset_index(ycoord).set_xindex(ycoord) save_regions = da.bout._regions da = xr.concat((da, da_outer), xcoord, join="exact") @@ -1559,7 +1563,9 @@ def _concat_lower_guards(da, da_global, mxg, myg): # https://github.com/pydata/xarray/issues/4393 # da_lower = da_lower.assign_coords(**{xcoord: new_xcoord, ycoord: new_ycoord}) da_lower[xcoord].data[...] = new_xcoord.data + da_lower = da_lower.reset_index(xcoord).set_xindex(xcoord) da_lower[ycoord].data[...] = new_ycoord.data + da_lower = da_lower.reset_index(ycoord).set_xindex(ycoord) if "poloidal_distance" in da.coords and myg > 0: # Special handling for core regions to deal with branch cut @@ -1675,7 +1681,9 @@ def _concat_upper_guards(da, da_global, mxg, myg): # https://github.com/pydata/xarray/issues/4393 # da_upper = da_upper.assign_coords(**{xcoord: new_xcoord, ycoord: new_ycoord}) da_upper[xcoord].data[...] = new_xcoord.data + da_upper = da_upper.reset_index(xcoord).set_xindex(xcoord) da_upper[ycoord].data[...] = new_ycoord.data + da_upper = da_upper.reset_index(ycoord).set_xindex(ycoord) if "poloidal_distance" in da.coords and myg > 0: # Special handling for core regions to deal with branch cut diff --git a/xbout/tests/test_boutdataset.py b/xbout/tests/test_boutdataset.py index 2d97ecb0..4ef8c6fe 100644 --- a/xbout/tests/test_boutdataset.py +++ b/xbout/tests/test_boutdataset.py @@ -42,7 +42,7 @@ def test_concat(self, bout_xyt_example_files): datapath=dataset_list2, inputfilepath=None, keep_xboundaries=False ) result = concat([bd1, bd2], dim="run") - assert result.dims == {**bd1.dims, "run": 2} + assert result.sizes == {**bd1.sizes, "run": 2} def test_isel(self, bout_xyt_example_files): dataset_list = bout_xyt_example_files(None, nxpe=1, nype=1, nt=1) diff --git a/xbout/tests/test_load.py b/xbout/tests/test_load.py index bb4c917e..20ddcae0 100644 --- a/xbout/tests/test_load.py +++ b/xbout/tests/test_load.py @@ -16,12 +16,25 @@ _trim, _infer_contains_boundaries, open_boutdataset, - _BOUT_PER_PROC_VARIABLES, - _BOUT_TIME_DEPENDENT_META_VARS, ) from xbout.utils import _separate_metadata from xbout.tests.utils_for_tests import create_bout_ds, METADATA_VARS +_BOUT_PER_PROC_VARIABLES = [ + "wall_time", + "wtime", + "wtime_rhs", + "wtime_invert", + "wtime_comms", + "wtime_io", + "wtime_per_rhs", + "wtime_per_rhs_e", + "wtime_per_rhs_i", + "PE_XIND", + "PE_YIND", + "MYPE", +] + def test_check_extensions(tmp_path): files_dir = tmp_path.joinpath("data") @@ -263,10 +276,12 @@ def test_strip_metadata(self): ds, metadata = _separate_metadata(original) - assert original.drop_vars( - METADATA_VARS + _BOUT_PER_PROC_VARIABLES + _BOUT_TIME_DEPENDENT_META_VARS, - errors="ignore", - ).equals(ds) + xrt.assert_equal( + original.drop_vars( + METADATA_VARS + _BOUT_PER_PROC_VARIABLES, errors="ignore" + ), + ds, + ) assert metadata["NXPE"] == 1 @@ -285,10 +300,7 @@ def test_single_file(self, tmp_path_factory, bout_xyt_example_files): xrt.assert_equal( actual.drop_vars(["x", "y", "z"]).load(), expected.drop_vars( - METADATA_VARS - + 
_BOUT_PER_PROC_VARIABLES - + _BOUT_TIME_DEPENDENT_META_VARS, - errors="ignore", + METADATA_VARS + _BOUT_PER_PROC_VARIABLES, errors="ignore" ), ) @@ -311,10 +323,7 @@ def test_squashed_file(self, tmp_path_factory, bout_xyt_example_files): xrt.assert_equal( actual.drop_vars(["x", "y", "z"]).load(), expected.drop_vars( - METADATA_VARS - + _BOUT_PER_PROC_VARIABLES - + _BOUT_TIME_DEPENDENT_META_VARS, - errors="ignore", + METADATA_VARS + _BOUT_PER_PROC_VARIABLES, errors="ignore" ), ) @@ -563,12 +572,7 @@ def test_toroidal(self, tmp_path_factory, bout_xyt_example_files): # check creation without writing to disk gives identical result fake_ds_list, fake_grid_ds = bout_xyt_example_files( - None, - nxpe=3, - nype=3, - nt=1, - syn_data_type="stepped", - grid="grid", + None, nxpe=3, nype=3, nt=1, syn_data_type="stepped", grid="grid" ) fake = open_boutdataset( datapath=fake_ds_list, geometry="toroidal", gridfilepath=fake_grid_ds @@ -930,21 +934,3 @@ def test_keep_yboundaries_doublenull_by_filenum( if not upper: expected = expected.isel(y=slice(None, -2, None)) xrt.assert_equal(expected, actual) - - @pytest.mark.parametrize("is_restart", [False, True]) - def test_trim_timing_info(self, is_restart): - ds = create_test_data(0) - from xbout.load import _BOUT_PER_PROC_VARIABLES - - # remove a couple of entries from _BOUT_PER_PROC_VARIABLES so we test that _trim - # does not fail if not all of them are present - _BOUT_PER_PROC_VARIABLES = _BOUT_PER_PROC_VARIABLES[:-2] - - for v in _BOUT_PER_PROC_VARIABLES: - ds[v] = 42.0 - ds = _trim( - ds, guards={}, keep_boundaries={}, nxpe=1, nype=1, is_restart=is_restart - ) - - expected = create_test_data(0) - xrt.assert_equal(ds, expected) diff --git a/xbout/utils.py b/xbout/utils.py index 66dd9593..33f93571 100644 --- a/xbout/utils.py +++ b/xbout/utils.py @@ -676,12 +676,7 @@ def _follow_boundary(ds, start_region, start_direction, xbndry, ybndry, Rcoord, for boundary in check_order[direction]: result = _bounding_surface_checks[boundary]( - ds_region, - boundary_points, - xbndry, - ybndry, - Rcoord, - Zcoord, + ds_region, boundary_points, xbndry, ybndry, Rcoord, Zcoord ) if result is not None: boundary_points, this_region, direction = result @@ -762,13 +757,7 @@ def _get_bounding_surfaces(ds, coords): start_direction = "upper_y" boundary, checked_regions = _follow_boundary( - ds, - start_region, - start_direction, - xbndry, - ybndry, - Rcoord, - Zcoord, + ds, start_region, start_direction, xbndry, ybndry, Rcoord, Zcoord ) boundaries = [boundary] @@ -832,9 +821,7 @@ def _get_bounding_surfaces(ds, coords): # Pack the result into a DataArray result = [ xr.DataArray( - boundary, - dims=("boundary", "coord"), - coords={"coord": [Rcoord, Zcoord]}, + boundary, dims=("boundary", "coord"), coords={"coord": [Rcoord, Zcoord]} ) for boundary in boundaries ] @@ -860,3 +847,14 @@ def _set_as_coord(ds, name): except ValueError: pass return ds + + +def _maybe_rename_dimension(ds, old_name, new_name): + if old_name in ds.dims and new_name != old_name: + # Rename dimension + ds = ds.swap_dims({old_name: new_name}) + if old_name in ds: + # Rename coordinate if it exists + ds = ds.rename({old_name: new_name}) + + return ds
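
The new Cherab integration above (`with_cherab_grid`, `as_cherab_data`, `as_cherab_emitter` and the `xbout.cherab` subpackage) is easiest to follow end-to-end. The following is a minimal usage sketch, not part of the patch: it assumes an axisymmetric toroidal dataset (a `zeta` dimension of size 1), a hypothetical dump-file path and variable name `"P"`, and that the optional `cherab`/`raysect` dependencies are installed (e.g. via the new `xbout[cherab]` extra).

```python
# Illustrative only: the path and the variable name "P" are hypothetical.
from xbout import open_boutdataset

ds = open_boutdataset("data/BOUT.dmp.*.nc", geometry="toroidal")

# Triangulate the R-Z grid once and attach it to every variable ...
ds = ds.bout.with_cherab_grid()

# ... then convert a single 2D (x, theta) slice to Cherab's triangular form
da = ds["P"].isel(t=-1)
tri = da.bout.as_cherab_data()   # TriangularData reusing the shared grid
tri.plot_2d()                    # sample the Discrete2DMesh and plot it
```

Computing the grid once on the Dataset and reusing it per variable avoids repeating the vertex de-duplication in `Triangulate` (a linear search) for every field.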
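
For reference, the `Triangulate` helper added in `xbout/cherab/triangulate.py` can be exercised on its own, since `raysect` is only imported lazily inside `TriangularData`. A small sketch with a single quadrilateral cell, using the corner ordering stacked in `grid.py`:

```python
import numpy as np
from xbout.cherab.triangulate import Triangulate

# One quad cell; corners stacked as (upper-right, upper-left,
# lower-right, lower-left), matching da_with_cherab_grid()
rm = np.array([[[1.1, 1.0, 1.1, 1.0]]])   # shape (nx=1, ny=1, 4)
zm = np.array([[[0.1, 0.1, 0.0, 0.0]]])

tri = Triangulate(rm, zm)
print(tri.vertices.shape)   # (4, 2): shared, de-duplicated vertices
print(tri.triangles)        # [[0 1 2] [3 2 1]]: two triangles per quad
print(tri.cell_number)      # [[0]]: coordinate used for slicing before Cherab
```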
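
The `open_mfdataset`/`combine_nested` changes in `load.py` replace the old `_BOUT_PER_PROC_VARIABLES` drop-list with combine options that never concatenate variables lacking the concatenation dimension, taking duplicates from the first file instead. A self-contained sketch of that behaviour (toy per-process datasets, not real xBOUT data):

```python
import numpy as np
import xarray as xr

# Two "per-processor" datasets: `n` is split along x, while `wall_time`
# is a per-process scalar that should not be concatenated.
ds0 = xr.Dataset({"n": ("x", np.arange(3.0)), "wall_time": 1.2})
ds1 = xr.Dataset({"n": ("x", np.arange(3.0, 6.0)), "wall_time": 3.4})

combined = xr.combine_nested(
    [ds0, ds1],
    concat_dim="x",
    data_vars="minimal",    # only concatenate variables that already have "x"
    coords="minimal",
    compat="override",      # duplicates (e.g. wall_time) taken from the first dataset
    join="exact",
    combine_attrs="override",
)

print(combined["n"].values)          # [0. 1. 2. 3. 4. 5.]
print(float(combined["wall_time"]))  # 1.2 -- from ds0, not concatenated
```

This is why `_trim` no longer needs to drop the timing variables explicitly, and why the test suite now defines its own local `_BOUT_PER_PROC_VARIABLES` list.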