diff --git a/docs/ArtificialInformation_Filter.ipynb b/docs/ArtificialInformation_Filter.ipynb new file mode 100644 index 00000000..c9f8d4cc --- /dev/null +++ b/docs/ArtificialInformation_Filter.ipynb @@ -0,0 +1,335 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "38ac6d1a", + "metadata": {}, + "source": [ + "**<<<<<<< local**" + ] + }, + { + "cell_type": "markdown", + "id": "895e1d5a", + "metadata": {}, + "source": [ + "# Artificial information filtering\n", + "\n", + "In simple terms the bitinformation is retrieved by checking how variable a bit pattern is. However, this approach cannot distinguish between actual information content and artifical information content. By studying the distribution of the information content the user can often identify clear cut-offs of real information content and artificial information content.\n", + "\n", + "The following example shows how such a separation of real information and artificial information can look like. To do so, artificial information is artificially added to an example dataset by applying linear quantization. Linear quantization is often applied to climate datasets (e.g. ERA5) and needs to be accounted for in order to retrieve meaningful bitinformation content. An algorithm that aims at detecting this artificial information itself is introduced." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c37dd36", + "metadata": {}, + "outputs": [], + "source": [ + "import xarray as xr\n", + "import xbitinfo as xb\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "id": "e8e1424f", + "metadata": {}, + "source": [ + "## Loading example dataset\n", + "We use here the openly accessible CONUS dataset. The dataset is available at full precision." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b18b9e24", + "metadata": {}, + "outputs": [], + "source": [ + "ds = xr.open_zarr(\n", + " \"s3://hytest/conus404/conus404_hourly.zarr\",\n", + " storage_options={\n", + " \"anon\": True,\n", + " \"requester_pays\": False,\n", + " \"client_kwargs\": {\"endpoint_url\": \"https://usgs.osn.mghpcc.org\"},\n", + " },\n", + ")\n", + "# selecting water vapor mixing ratio at 2 meters\n", + "data = ds[\"ACSWUPB\"]\n", + "# select subset of data for demonstration purposes\n", + "chunk = data.isel(time=slice(0, 9), y=slice(0, 525), x=slice(0, 525))\n", + "chunk" + ] + }, + { + "cell_type": "markdown", + "id": "535ce421", + "metadata": {}, + "source": [ + "## Creating dataset copy with artificial information\n", + "### Functions to encode and decode" + ] + }, + { + "cell_type": "markdown", + "id": "69543b4c", + "metadata": {}, + "source": [ + "**=======**" + ] + }, + { + "cell_type": "markdown", + "id": "1842f792", + "metadata": {}, + "source": [ + "# Artificial information filtering\n", + "\n", + "In simple terms the bitinformation is retrieved by checking how variable a bit pattern is. However, this approach cannot distinguish between actual information content and artifical information content. By studying the distribution of the information content the user can often identify clear cut-offs of real information content and artificial information content.\n", + "\n", + "The following example shows how such a separation of real information and artificial information can look like. To do so, artificial information is artificially added to an example dataset by applying linear quantization. Linear quantization is often applied to climate datasets (e.g. ERA5) and needs to be accounted for in order to retrieve meaningful bitinformation content. An algorithm that aims at detecting this artificial information itself is introduced." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb998fbb", + "metadata": {}, + "outputs": [], + "source": [ + "import xarray as xr\n", + "import xbitinfo as xb\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "id": "32ac97e0", + "metadata": {}, + "source": [ + "## Loading example dataset\n", + "We use here the openly accessible CONUS dataset. The dataset is available at full precision." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9639a618", + "metadata": {}, + "outputs": [], + "source": [ + "ds = xr.open_zarr(\n", + " \"s3://hytest/conus404/conus404_monthly.zarr\",\n", + " storage_options={\n", + " \"anon\": True,\n", + " \"requester_pays\": False,\n", + " \"client_kwargs\": {\"endpoint_url\": \"https://usgs.osn.mghpcc.org\"},\n", + " },\n", + ")\n", + "# selecting water vapor mixing ratio at 2 meters\n", + "data = ds[\"ACSWDNT\"]\n", + "# select subset of data for demonstration purposes\n", + "chunk = data.isel(time=slice(0, 2), y=slice(0, 1015), x=slice(0, 1050))\n", + "chunk" + ] + }, + { + "cell_type": "markdown", + "id": "3d735e4b", + "metadata": {}, + "source": [ + "## Creating dataset copy with artificial information\n", + "### Functions to encode and decode" + ] + }, + { + "cell_type": "markdown", + "id": "0d30feaa", + "metadata": {}, + "source": [ + "**>>>>>>> remote**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b3a7c7ae", + "metadata": {}, + "outputs": [], + "source": [ + "# Encoding function to compress data\n", + "def encode(chunk, scale, offset, dtype, astype):\n", + " enc = (chunk - offset) * scale\n", + " enc = np.around(enc)\n", + " enc = enc.astype(astype, copy=False)\n", + " return enc\n", + "\n", + "\n", + "# Decoding function to decompress data\n", + "def decode(enc, scale, offset, dtype, astype):\n", + " dec = (enc / scale) + offset\n", + " dec = dec.astype(dtype, copy=False)\n", + " return dec" + ] + }, + { + "cell_type": "markdown", + "id": "fa6f26c7", + "metadata": {}, + "source": [ + "### Transform dataset to introduce artificial information" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c09e3cf3", + "metadata": {}, + "outputs": [], + "source": [ + "xmin = np.min(chunk)\n", + "xmax = np.max(chunk)\n", + "scale = (2**16 - 1) / (xmax - xmin)\n", + "offset = xmin\n", + "enc = encode(chunk, scale, offset, \"f4\", \"u2\")\n", + "dec = decode(enc, scale, offset, \"f4\", \"u2\")" + ] + }, + { + "cell_type": "markdown", + "id": "7126810d", + "metadata": {}, + "source": [ + "## Comparison of bitinformation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05ef8a94", + "metadata": {}, + "outputs": [], + "source": [ + "# original dataset without artificial information\n", + "orig_info = xb.get_bitinformation(\n", + " xr.Dataset({\"w/o artif. info\": chunk}),\n", + " dim=\"x\",\n", + " implementation=\"python\",\n", + ")\n", + "\n", + "# dataset with artificial information\n", + "arti_info = xb.get_bitinformation(\n", + " xr.Dataset({\"w artif. info\": dec}),\n", + " dim=\"x\",\n", + " implementation=\"python\",\n", + ")\n", + "\n", + "# plotting distribution of bitwise information content\n", + "info = xr.merge([orig_info, arti_info])\n", + "plot = xb.plot_bitinformation(info)" + ] + }, + { + "cell_type": "markdown", + "id": "de1ecb7e", + "metadata": {}, + "source": [ + "The figure reveals that artificial information is introduced by applying linear quantization. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8600d4b8", + "metadata": {}, + "outputs": [], + "source": [ + "keepbits = xb.get_keepbits(info, inflevel=[0.99])\n", + "print(\n", + " f\"The number of keepbits increased from {keepbits['w/o artif. info'].item(0)} bits in the original dataset to {keepbits['w artif. info'].item(0)} bits in the dataset with artificial information.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "fa80f988", + "metadata": {}, + "source": [ + "In the following, a gradient based filter is introduced to remove this artificial information again so that even in case artificial information is present in a dataset the number of keepbits remains similar." + ] + }, + { + "cell_type": "markdown", + "id": "3f7a7c2e", + "metadata": {}, + "source": [ + "## Artificial information filter\n", + "The filter `gradient` works as follows:\n", + "\n", + "1. It determines the Cumulative Distribution Function(CDF) of the bitwise information content\n", + "2. It computes the gradient of the CDF to identify points where the gradient becomes close to a given tolerance indicating a drop in information.\n", + "3. Simultaneously, it keeps track of the minimum cumulative sum of information content which is threshold here, which signifies at least this much fraction of total information needs to be passed.\n", + "4. So the bit where the intersection of the gradient reaching the tolerance and the cumulative sum exceeding the threshold is our TrueKeepbits. All bits beyond this index are assumed to contain artificial information and are set to zero in order to cut them off.\n", + "5. You can see the above concept implemented in the function get_cdf_without_artificial_information in xbitinfo.py\n", + "\n", + "Please note that this filter relies on a clear separation between real and artificial information content and might not work in all cases." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0ab6633", + "metadata": {}, + "outputs": [], + "source": [ + "xb.get_keepbits(\n", + " arti_info,\n", + " inflevel=[0.99],\n", + " information_filter=\"Gradient\",\n", + " **{\"threshold\": 0.7, \"tolerance\": 0.001}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "21c6369d", + "metadata": {}, + "source": [ + "With the application of the filter the keepbits are closer/identical to their original value in the dataset without artificial information. The plot of the bitinformation visualizes this:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e9183b2", + "metadata": {}, + "outputs": [], + "source": [ + "plot = xb.plot_bitinformation(arti_info, information_filter=\"Gradient\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/environment.yml b/docs/environment.yml index 23e08aa9..a1ac1588 100644 --- a/docs/environment.yml +++ b/docs/environment.yml @@ -21,6 +21,9 @@ dependencies: - sphinx-book-theme>=0.1.7 - myst-nb - numcodecs>=0.10.0 + - intake-xarray + - metpy + - s3fs - pip - pip: - -e ../. diff --git a/docs/index.rst b/docs/index.rst index 4b38309f..445beefc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -96,17 +96,16 @@ Credits quick-start.ipynb - **User Guide** -* :doc:`chunking` +* :doc:`artificialinformation` .. toctree:: - :maxdepth: 1 - :hidden: - :caption: User Guide + :maxdepth: 1 + :hidden: + :caption: User Guide - chunking.ipynb + ArtificialInformation_Filter.ipynb **Help & Reference** diff --git a/environment.yml b/environment.yml index a75c755e..b728b45e 100644 --- a/environment.yml +++ b/environment.yml @@ -28,6 +28,9 @@ dependencies: - sphinx-book-theme - myst-nb - numcodecs>=0.10.0 + - pytest-lazy-fixture + - aiohttp + - s3fs - pip - pip: - -e . diff --git a/tests/test_get_keepbits.py b/tests/test_get_keepbits.py index cb24cd13..6e71a8a8 100644 --- a/tests/test_get_keepbits.py +++ b/tests/test_get_keepbits.py @@ -30,3 +30,144 @@ def test_get_keepbits_inflevel_dim(rasm_info_per_bit, inflevel): if isinstance(inflevel, (int, float)): inflevel = [inflevel] assert (keepbits.inflevel == inflevel).all() + + +def test_get_keepbits_informationFilter(): + """ + Test the `get_keepbits` function with different information filters. + + This test function checks the behavior of the `get_keepbits` function when applying gradient information filter. + The dataset contains artificial information and thus applying the filter should result in lesser number of bits + than what should be when filter is None. + + + Raises: + AssertionError: If the test conditions are not met. + + """ + + bit32_values = [ + "±", + "e1", + "e2", + "e3", + "e4", + "e5", + "e6", + "e7", + "e8", + "m1", + "m2", + "m3", + "m4", + "m5", + "m6", + "m7", + "m8", + "m9", + "m10", + "m11", + "m12", + "m13", + "m14", + "m15", + "m16", + "m17", + "m18", + "m19", + "m20", + "m21", + "m22", + "m23", + ] + data_variable = xr.DataArray( + data=[ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 1.11799129e-01, + 8.19114977e-01, + 4.41578500e-01, + 3.25470303e-01, + 4.35195738e-01, + 2.81462993e-01, + 2.10719742e-01, + 1.46638224e-01, + 9.24534031e-02, + 4.41090879e-02, + 1.13842504e-02, + 8.20088050e-04, + 2.62239097e-06, + 7.11284508e-07, + 1.18183485e-06, + 9.49338973e-09, + 1.80859255e-07, + 7.72662891e-07, + 1.37865391e-05, + 2.11117224e-06, + 2.01353088e-07, + 3.20755770e-02, + 6.06721012e-03, + 2.25987148e-04, + 1.71530452e-06, + 5.13067595e-03, + ], + coords={"bit32": bit32_values, "dim": "x"}, + dims=["bit32"], + name="RH2", + ) + info_ds = xr.Dataset({"RH2": data_variable}) + Keepbits_FilterNone = xb.get_keepbits( + info_ds, + inflevel=[0.99], + information_filter=None, + **{"threshold": 0.7, "tolerance": 0.001} + ) + Keepbits_FilterNone_Value = Keepbits_FilterNone["RH2"].values + assert Keepbits_FilterNone_Value == 19 + + Keepbits_FilterGradient = xb.get_keepbits( + info_ds, + inflevel=[0.99], + information_filter="Gradient", + **{"threshold": 0.7, "tolerance": 0.001} + ) + Keepbits_FilterGradient_Value = Keepbits_FilterGradient["RH2"].values + assert Keepbits_FilterGradient_Value == 7 + + +def test_get_keepbits_informationFilter_1(): + """ + Test the `get_keepbits` function with different information filters. + + This test function checks the behavior of the `get_keepbits` function when applying gradient information filter. + The dataset does not contain artificial information and thus the number of keepbits when gradient filter is applied + should be equal to when filter is None. + + Raises: + AssertionError: If the test conditions are not met. + + """ + + ds = xr.tutorial.load_dataset("air_temperature") + info = xb.get_bitinformation(ds, dim="lat") + Keepbits_FilterNone = xb.get_keepbits( + info, + inflevel=[0.99], + information_filter=None, + **{"threshold": 0.7, "tolerance": 0.001} + ) + Keepbits_FilterNone_Value = Keepbits_FilterNone["air"].values + + Keepbits_FilterGradient = xb.get_keepbits( + info, + inflevel=[0.99], + information_filter="Gradient", + **{"threshold": 0.7, "tolerance": 0.001} + ) + + Keepbits_FilterGradient_Value = Keepbits_FilterGradient["air"].values + assert Keepbits_FilterNone_Value == Keepbits_FilterGradient_Value diff --git a/xbitinfo/graphics.py b/xbitinfo/graphics.py index bef4037a..7a3b93d6 100644 --- a/xbitinfo/graphics.py +++ b/xbitinfo/graphics.py @@ -2,7 +2,7 @@ import numpy as np import xarray as xr -from .xbitinfo import _cdf_from_info_per_bit, bit_partitioning, get_keepbits +from .xbitinfo import NMBITS, _cdf_from_info_per_bit, get_keepbits def add_bitinfo_labels( @@ -117,12 +117,16 @@ def add_bitinfo_labels( CDF = _cdf_from_info_per_bit(info_per_bit, dimension) CDF_DataArray = CDF[da.name] - data_type = np.dtype(dimension.replace("bit", "")) - _, _, n_exp, _ = bit_partitioning(data_type) if inflevels is None: inflevels = [] for i, keep in enumerate(keepbits): - mantissa_index = keep + n_exp + if dimension == "bit16": + mantissa_index = keep + 5 + if dimension == "bit32": + mantissa_index = keep + 8 + if dimension == "bit64": + mantissa_index = keep + 11 + inflevels.append(CDF_DataArray[mantissa_index].values) if keepbits is None: @@ -181,29 +185,7 @@ def add_bitinfo_labels( t_keepbits.set_bbox(dict(facecolor="white", alpha=0.9, edgecolor="white")) -def split_dataset_by_dims(info_per_bit): - """Split dataset by its dimensions. - - Parameters - ---------- - info_per_bit : dict - Information content of each bit for each variable in ``da``. This is the output from :py:func:`xbitinfo.xbitinfo.get_bitinformation`. - - Returns - ------- - var_by_dim : dict - Dictionary containing the dimensions of the datasets as keys and the dataset using the dimension as value. - """ - var_by_dim = {d: [] for d in info_per_bit.dims} - for var in info_per_bit.data_vars: - assert ( - len(info_per_bit[var].dims) == 1 - ), f"Variable {var} has more than one dimension." - var_by_dim[info_per_bit[var].dims[0]].append(var) - return var_by_dim - - -def plot_bitinformation(bitinfo, cmap="turku", crop=None): +def plot_bitinformation(bitinfo, information_filter=None, cmap="turku"): """Plot bitwise information content as in Klöwer et al. 2021 Figure 2. Klöwer, M., Razinger, M., Dominguez, J. J., Düben, P. D., & Palmer, T. N. (2021). @@ -216,9 +198,12 @@ def plot_bitinformation(bitinfo, cmap="turku", crop=None): Containing the bitwise information content for each variable cmap : str or plt.cm Colormap. Defaults to ``"turku"``. - crop : int - Maximum bits to show in figure. + Kwargs + threshold(` `float ``) : defaults to ``0.7`` + Minimum cumulative sum of information content before artificial information filter is applied. + tolerance(` `float ``) : defaults to ``0.001`` + The tolerance is the value below which gradient starts becoming constant Returns ------- fig : matplotlib figure @@ -228,79 +213,121 @@ def plot_bitinformation(bitinfo, cmap="turku", crop=None): >>> ds = xr.tutorial.load_dataset("air_temperature") >>> info_per_bit = xb.get_bitinformation(ds, dim="lon") >>> xb.plot_bitinformation(info_per_bit) -
+
""" import matplotlib.pyplot as plt - bitinfo = bitinfo.squeeze() + assert bitinfo.coords["dim"].shape <= ( + 1, + ), "Only bitinfo along one dimension is supported at the moment. Please select dimension before plotting." + assert ( - "dim" not in bitinfo.dims - ), "Found dependence of bitinformation on dimension. Please reduce data first by e.g. `bitinfo.max(dim='dim')`" - vars_by_dim = split_dataset_by_dims(bitinfo) - bitinfo_all = bitinfo - subfigure_data = [None] * len(vars_by_dim) - for d, (dim, vars) in enumerate(vars_by_dim.items()): - bitinfo = bitinfo_all[vars] - data_type = np.dtype(dim.replace("bit", "")) - n_bits, n_sign, n_exp, n_mant = bit_partitioning(data_type) - nonmantissa_bits = n_bits - n_mant - if crop is None: - bits_to_show = n_bits - else: - bits_to_show = int(np.min([crop, n_bits])) - nvars = len(bitinfo) - varnames = list(bitinfo.keys()) + "bit32" in bitinfo.dims + ), "currently only works properly for float32 data, looking forward to your PR closing https://github.com/observingClouds/xbitinfo/issues/168" + nvars = len(bitinfo) + varnames = bitinfo.keys() + + if information_filter == "Gradient": + infbits_dict = get_keepbits( + bitinfo, 0.99, information_filter, **{"threshold": 0.7, "tolerance": 0.001} + ) + infbits100_dict = get_keepbits( + bitinfo, + 0.999999999, + information_filter, + **{"threshold": 0.7, "tolerance": 0.001}, + ) + else: infbits_dict = get_keepbits(bitinfo, 0.99) infbits100_dict = get_keepbits(bitinfo, 0.999999999) - ICnan = np.zeros((nvars, 64)) - infbits = np.zeros(nvars) - infbits100 = np.zeros(nvars) - ICnan[:, :] = np.nan - for v, var in enumerate(varnames): - ic = bitinfo[var].squeeze(drop=True) - ICnan[v, : len(ic)] = ic - # infbits are all bits, infbits_dict were mantissa bits - infbits[v] = infbits_dict[var] + nonmantissa_bits - infbits100[v] = infbits100_dict[var] + nonmantissa_bits - ICnan = np.where(ICnan == 0, np.nan, ICnan) - ICcsum = np.nancumsum(ICnan, axis=1) - - infbitsy = np.hstack([0, np.repeat(np.arange(1, nvars), 2), nvars]) - infbitsx = np.repeat(infbits, 2) - infbitsx100 = np.repeat(infbits100, 2) - - fig_height = np.max([4, 4 + (nvars - 10) * 0.2]) # auto adjust to nvars - - subfigure_data[d] = {} - subfigure_data[d]["fig_height"] = fig_height - subfigure_data[d]["nvars"] = nvars - subfigure_data[d]["varnames"] = varnames - subfigure_data[d]["ICnan"] = ICnan - subfigure_data[d]["ICcsum"] = ICcsum - subfigure_data[d]["infbits"] = infbits - subfigure_data[d]["infbitsx"] = infbitsx - subfigure_data[d]["infbitsy"] = infbitsy - subfigure_data[d]["infbitsx100"] = infbitsx100 - subfigure_data[d]["nbits"] = (n_sign, n_exp, n_bits, n_mant, nonmantissa_bits) - subfigure_data[d]["bits_to_show"] = bits_to_show - - fig_heights = [subfig["fig_height"] for subfig in subfigure_data] - fig = plt.figure(figsize=(12, sum(fig_heights) + 2 * 2)) - fig_heights_incl_cax = fig_heights + [2 / (sum(fig_heights) + 2)] * 2 - grid = fig.add_gridspec( - len(subfigure_data) + 2, 1, height_ratios=fig_heights_incl_cax + ICnan = np.zeros((nvars, 64)) + infbits = np.zeros(nvars) + infbits100 = np.zeros(nvars) + ICnan[:, :] = np.nan + for v, var in enumerate(varnames): + ic = bitinfo[var].squeeze(drop=True) + ICnan[v, : len(ic)] = ic + # infbits are all bits, infbits_dict were mantissa bits + infbits[v] = infbits_dict[var] + NMBITS[len(ic)] + infbits100[v] = infbits100_dict[var] + NMBITS[len(ic)] + ICnan = np.where(ICnan == 0, np.nan, ICnan) + ICcsum = np.nancumsum(ICnan, axis=1) + + infbitsy = np.hstack([0, np.repeat(np.arange(1, nvars), 2), nvars]) + infbitsx = np.repeat(infbits, 2) + infbitsx100 = np.repeat(infbits100, 2) + + fig_height = np.max([4, 4 + (nvars - 10) * 0.2]) # auto adjust to nvars + fig, ax1 = plt.subplots(1, 1, figsize=(12, fig_height), sharey=True) + ax1.invert_yaxis() + ax1.set_box_aspect(1 / 32 * nvars) + plt.tight_layout(rect=[0.06, 0.18, 0.8, 0.98]) + pos = ax1.get_position() + cax = fig.add_axes([pos.x0, 0.12, pos.x1 - pos.x0, 0.02]) + + ax1right = ax1.twinx() + ax1right.invert_yaxis() + ax1right.set_box_aspect(1 / 32 * nvars) + + if cmap == "turku": + import cmcrameri.cm as cmc + + cmap = cmc.turku_r + pcm = ax1.pcolormesh(ICnan, vmin=0, vmax=1, cmap=cmap) + cbar = plt.colorbar(pcm, cax=cax, orientation="horizontal") + cbar.set_label("information content [bit]") + + # 99% of real information enclosed + ax1.plot( + np.hstack([infbits, infbits[-1]]), + np.arange(nvars + 1), + "C1", + ds="steps-pre", + zorder=10, + label="99% of\ninformation", ) - axs = [] - for i in range(len(subfigure_data) + 2): - ax = fig.add_subplot(grid[i, 0]) - axs.append(ax) + # grey shading + ax1.fill_betweenx( + infbitsy, infbitsx, np.ones(len(infbitsx)) * 32, alpha=0.4, color="grey" + ) + ax1.fill_betweenx( + infbitsy, infbitsx100, np.ones(len(infbitsx)) * 32, alpha=0.1, color="c" + ) + ax1.fill_betweenx( + infbitsy, + infbitsx100, + np.ones(len(infbitsx)) * 32, + alpha=0.3, + facecolor="none", + edgecolor="c", + ) - if isinstance(axs, plt.Axes): - axs = [axs] + # for legend only + ax1.fill_betweenx( + [-1, -1], + [-1, -1], + [-1, -1], + color="burlywood", + label="last 1% of\ninformation", + alpha=0.5, + ) + ax1.fill_betweenx( + [-1, -1], + [-1, -1], + [-1, -1], + facecolor="teal", + edgecolor="c", + label="false information", + alpha=0.3, + ) + ax1.fill_betweenx([-1, -1], [-1, -1], [-1, -1], color="w", label="unused bits") + + ax1.axvline(1, color="k", lw=1, zorder=3) + ax1.axvline(9, color="k", lw=1, zorder=3) fig.suptitle( "Real bitwise information content", @@ -310,187 +337,48 @@ def plot_bitinformation(bitinfo, cmap="turku", crop=None): horizontalalignment="left", ) - if cmap == "turku": - import cmcrameri.cm as cmc - - cmap = cmc.turku_r - - max_bits_to_show = np.max([d["bits_to_show"] for d in subfigure_data]) - - for d, subfig in enumerate(subfigure_data): - infbits = subfig["infbits"] - nvars = subfig["nvars"] - n_sign, n_exp, n_bits, n_mant, nonmantissa_bits = subfig["nbits"] - ICcsum = subfig["ICcsum"] - ICnan = subfig["ICnan"] - infbitsy = subfig["infbitsy"] - infbitsx = subfig["infbitsx"] - infbitsx100 = subfig["infbitsx100"] - varnames = subfig["varnames"] - bits_to_show = subfig["bits_to_show"] - - mbits_to_show = bits_to_show - nonmantissa_bits - - axs[d].invert_yaxis() - axs[d].set_box_aspect(1 / max_bits_to_show * nvars) - - ax1right = axs[d].twinx() - ax1right.invert_yaxis() - ax1right.set_box_aspect(1 / max_bits_to_show * nvars) - - pcm = axs[d].pcolormesh(ICnan, vmin=0, vmax=1, cmap=cmap) - - if d == len(subfigure_data) - 1: - cax = axs[len(subfigure_data)] - lax = axs[len(subfigure_data) + 1] - lax.axis("off") - cbar = plt.colorbar(pcm, cax=cax, orientation="horizontal") - cbar.set_label("information content [bit]") - - # 99% of real information enclosed - l0 = axs[d].plot( - np.hstack([infbits, infbits[-1]]), - np.arange(nvars + 1), - "C1", - ds="steps-pre", - zorder=10, - label="99% of\ninformation", + ax1.set_xlim(0, 32) + ax1.set_ylim(nvars, 0) + ax1right.set_ylim(nvars, 0) + + ax1.set_yticks(np.arange(nvars) + 0.5) + ax1right.set_yticks(np.arange(nvars) + 0.5) + ax1.set_yticklabels(varnames) + ax1right.set_yticklabels([f"{i:4.1f}" for i in ICcsum[:, -1]]) + ax1right.set_ylabel("total information per value [bit]") + + ax1.text( + infbits[0] + 0.1, + 0.8, + f"{int(infbits[0]-9)} mantissa bits", + fontsize=8, + color="saddlebrown", + ) + for i in range(1, nvars): + ax1.text( + infbits[i] + 0.1, + (i) + 0.8, + f"{int(infbits[i]-9)}", + fontsize=8, + color="saddlebrown", ) - # grey shading - axs[d].fill_betweenx( - infbitsy, - infbitsx, - np.ones(len(infbitsx)) * bits_to_show, - alpha=0.4, - color="grey", - ) - axs[d].fill_betweenx( - infbitsy, - infbitsx100, - np.ones(len(infbitsx)) * bits_to_show, - alpha=0.1, - color="c", - ) - axs[d].fill_betweenx( - infbitsy, - infbitsx100, - np.ones(len(infbitsx)) * bits_to_show, - alpha=0.3, - facecolor="none", - edgecolor="c", - ) + ax1.set_xticks([1, 9]) + ax1.set_xticks(np.hstack([np.arange(1, 8), np.arange(9, 32)]), minor=True) + ax1.set_xticklabels([]) + ax1.text(0, nvars + 1.2, "sign", rotation=90) + ax1.text(2, nvars + 1.2, "exponent bits", color="darkslategrey") + ax1.text(10, nvars + 1.2, "mantissa bits") - # for legend only - l1 = axs[d].fill_betweenx( - [-1, -1], - [-1, -1], - [-1, -1], - color="burlywood", - label="last 1% of\ninformation", - alpha=0.5, + for i in range(1, 9): + ax1.text( + i + 0.5, nvars + 0.5, i, ha="center", fontsize=7, color="darkslategrey" ) - l2 = axs[d].fill_betweenx( - [-1, -1], - [-1, -1], - [-1, -1], - facecolor="teal", - edgecolor="c", - label="false information", - alpha=0.3, - ) - l3 = axs[d].fill_betweenx( - [-1, -1], [-1, -1], [-1, -1], color="w", label="unused bits", edgecolor="k" - ) - - if n_sign > 0: - axs[d].axvline(n_sign, color="k", lw=1, zorder=3) - axs[d].axvline(nonmantissa_bits, color="k", lw=1, zorder=3) - axs[d].set_ylim(nvars, 0) - ax1right.set_ylim(nvars, 0) - - axs[d].set_yticks(np.arange(nvars) + 0.5) - ax1right.set_yticks(np.arange(nvars) + 0.5) - axs[d].set_yticklabels(varnames) - ax1right.set_yticklabels([f"{i:4.1f}" for i in ICcsum[:, -1]]) - if d == len(subfigure_data) // 2: - ax1right.set_ylabel("total information\nper value [bit]") - - axs[d].text( - infbits[0] + 0.1, - 0.8, - f"{int(infbits[0]-nonmantissa_bits)} mantissa bits", - fontsize=8, - color="saddlebrown", - ) - for i in range(1, nvars): - axs[d].text( - infbits[i] + 0.1, - (i) + 0.8, - f"{int(infbits[i]-9)}", - fontsize=8, - color="saddlebrown", - ) - - major_xticks = np.array([n_sign, n_sign + n_exp, n_bits], dtype="int") - axs[d].set_xticks(major_xticks[major_xticks <= bits_to_show]) - minor_xticks = np.hstack( - [ - np.arange(n_sign, nonmantissa_bits - 1), - np.arange(nonmantissa_bits, n_bits - 1), - ] - ) - axs[d].set_xticks( - minor_xticks[minor_xticks <= bits_to_show], - minor=True, - ) - axs[d].set_xticklabels([]) - if n_sign > 0: - axs[d].text(0, nvars + 1.2, "sign", rotation=90) - if n_exp > 0: - axs[d].text( - n_sign + n_exp / 2, - nvars + 1.2, - "exponent bits", - color="darkslategrey", - horizontalalignment="center", - verticalalignment="center", - ) - axs[d].text( - n_sign + n_exp + mbits_to_show / 2, - nvars + 1.2, - "mantissa bits", - horizontalalignment="center", - verticalalignment="center", - ) + for i in range(1, 24): + ax1.text(8 + i + 0.5, nvars + 0.5, i, ha="center", fontsize=7) - # Set xticklabels - ## Set exponent labels - for e, i in enumerate(range(n_sign, np.min([n_sign + n_exp, bits_to_show]))): - axs[d].text( - i + 0.5, - nvars + 0.5, - e + 1, - ha="center", - fontsize=7, - color="darkslategrey", - ) - ## Set mantissa labels - for m, i in enumerate( - range(n_sign + n_exp, np.min([n_sign + n_exp + n_mant, bits_to_show])) - ): - axs[d].text(i + 0.5, nvars + 0.5, m + 1, ha="center", fontsize=7) - - if d == len(subfigure_data) - 1: - lax.legend( - bbox_to_anchor=(0.5, 0), - loc="center", - framealpha=0.6, - ncol=4, - handles=[l0[0], l1, l2, l3], - ) - axs[d].set_xlim(0, bits_to_show) + ax1.legend(bbox_to_anchor=(1.08, 0.5), loc="center left", framealpha=0.6) fig.show() diff --git a/xbitinfo/xbitinfo.py b/xbitinfo/xbitinfo.py index 9a98f408..9d9933ac 100644 --- a/xbitinfo/xbitinfo.py +++ b/xbitinfo/xbitinfo.py @@ -33,38 +33,31 @@ jl.eval("include(Main.path)") -def bit_partitioning(dtype): - if dtype.kind == "f": - n_bits = np.finfo(dtype).bits - n_sign = 1 - n_exponent = np.finfo(dtype).nexp - n_mantissa = np.finfo(dtype).nmant - elif dtype.kind == "i": - n_bits = np.iinfo(dtype).bits - n_sign = 1 - n_exponent = 0 - n_mantissa = n_bits - n_sign - elif dtype.kind == "u": - n_bits = np.iinfo(dtype).bits - n_sign = 0 - n_exponent = 0 - n_mantissa = n_bits - n_sign +NMBITS = {64: 12, 32: 9, 16: 6} # number of non mantissa bits for given dtype + + +def get_bit_coords(dtype_size): + """Get coordinates for bits assuming float dtypes.""" + if dtype_size == 16: + coords = ( + ["±"] + + [f"e{int(i)}" for i in range(1, 6)] + + [f"m{int(i-5)}" for i in range(6, 16)] + ) + elif dtype_size == 32: + coords = ( + ["±"] + + [f"e{int(i)}" for i in range(1, 9)] + + [f"m{int(i-8)}" for i in range(9, 32)] + ) + elif dtype_size == 64: + coords = ( + ["±"] + + [f"e{int(i)}" for i in range(1, 12)] + + [f"m{int(i-11)}" for i in range(12, 64)] + ) else: - raise ValueError(f"dtype {dtype} neither known nor implemented.") - assert ( - n_sign + n_exponent + n_mantissa == n_bits - ), "The components of the datatype could not be safely inferred." - return n_bits, n_sign, n_exponent, n_mantissa - - -def get_bit_coords(dtype): - """Get coordinates for bits based on dtype.""" - n_bits, n_sign, n_exponent, n_mantissa = bit_partitioning(dtype) - coords = ( - n_sign * ["±"] - + [f"e{int(i)}" for i in range(1, n_exponent + 1)] - + [f"m{int(i)}" for i in range(1, n_mantissa + 1)] - ) + raise ValueError(f"dtype of size {dtype_size} neither known nor implemented.") return coords @@ -72,13 +65,13 @@ def dict_to_dataset(info_per_bit): """Convert keepbits dictionary to :py:class:`xarray.Dataset`.""" dsb = xr.Dataset() for v in info_per_bit.keys(): - dtype = np.dtype(info_per_bit[v]["dtype"]) + dtype_size = len(info_per_bit[v]["bitinfo"]) dim = info_per_bit[v]["dim"] - dim_name = f"bit{dtype}" + dim_name = f"bit{dtype_size}" dsb[v] = xr.DataArray( info_per_bit[v]["bitinfo"], dims=[dim_name], - coords={dim_name: get_bit_coords(dtype), "dim": dim}, + coords={dim_name: get_bit_coords(dtype_size), "dim": dim}, name=v, attrs={ "long_name": f"{v} bitwise information", @@ -284,7 +277,6 @@ def _jl_get_bitinformation(ds, var, axis, dim, kwargs={}): ) info_per_bit["dim"] = dim info_per_bit["axis"] = axis_jl - 1 - info_per_bit["dtype"] = str(ds[var].dtype) return info_per_bit @@ -320,7 +312,6 @@ def _py_get_bitinformation(ds, var, axis, dim, kwargs={}): info_per_bit["bitinfo"] = pb.bitinformation(X, axis=axis).compute() info_per_bit["dim"] = dim info_per_bit["axis"] = axis - info_per_bit["dtype"] = str(ds[var].dtype) return info_per_bit @@ -385,7 +376,126 @@ def load_bitinformation(label): raise FileNotFoundError(f"No bitinformation could be found at {label+'.json'}") -def get_keepbits(info_per_bit, inflevel=0.99): +def get_cdf_without_artificial_information( + info_per_bit, bitdim, threshold, tolerance, bit_vars +): + """ + Calculate a Cumulative Distribution Function (CDF) with artificial information removal. + + This function calculates a modified CDF for a given set of bit information and variable dimensions, + removing artificial information while preserving the desired threshold of information content. + + 1.)The function's aim is to return the cdf in a way that artificial information gets removed. + 2.)This function calculates the CDF using the provided information content per bit dataset. + 3.)It then computes the gradient of the CDF values to identify points where the gradient becomes close to the given tolerance, + indicating a drop in information. + 4.)Simultaneously, it keeps track of the minimum cumulative sum of information content which is threshold here, which signifies atleast + this much fraction of total information needs to be passed. + 5.)So the bit where the intersection of the gradient reaching the tolerance and the cumulative sum exceeding the threshold. All bits beyond this + index are assumed to contain artificial information and are set to zero in the resulting CDF. + + + Parameters: + ----------- + info_per_bit : :py:class: 'xarray.Dataset' + Information content of each bit. This is the output from :py:func:`xbitinfo.xbitinfo.get_bitinformation`. + bitdim : str + The dimension representing the bit information. + threshold : float + Minimum cumulative sum of information content before artificial information filter is applied. + tolerance : float + The tolerance is the value below which gradient starts becoming constant + bit_vars : list + List of variable names of the dataset. + + Returns: + -------- + xarray.Dataset + A modified CDF dataset with artificial information removed. + + Example: + -------- + >>> ds = xr.tutorial.load_dataset("air_temperature") + >>> info = xb.get_bitinformation(ds) + >>> get_keepbits( + ... info, + ... inflevel=[0.99], + ... information_filter="Gradient", + ... **{"threshold": 0.7, "tolerance": 0.001} + ... ) + + Dimensions: (dim: 3, inflevel: 1) + Coordinates: + * dim (dim) = threshold * infSum: + infbits = i + break + + for i in range(0, infbits + 1): + # Normalize CDF values for elements up to 'infbits'. + cdf_array[i] = cdf_array[i] / cdf_array[infbits] + + cdf_array[(infbits + 1) :] = 1 + return cdf + + +def get_keepbits(info_per_bit, inflevel=0.99, information_filter=None, **kwargs): """Get the number of mantissa bits to keep. To be used in :py:func:`xbitinfo.bitround.xr_bitround` and :py:func:`xbitinfo.bitround.jl_bitround`. Parameters @@ -395,6 +505,13 @@ def get_keepbits(info_per_bit, inflevel=0.99): inflevel : float or list Level of information that shall be preserved. + Kwargs + threshold(` `float ``) : defaults to ``0.7`` + Minimum cumulative sum of information content before artificial information filter is applied. + tolerance(` `float ``) : defaults to ``0.001`` + The tolerance is the value below which gradient starts becoming constant + + Returns ------- keepbits : dict @@ -405,36 +522,36 @@ def get_keepbits(info_per_bit, inflevel=0.99): >>> ds = xr.tutorial.load_dataset("air_temperature") >>> info_per_bit = xb.get_bitinformation(ds, dim="lon") >>> xb.get_keepbits(info_per_bit) - Size: 28B + Dimensions: (inflevel: 1) Coordinates: - dim >> xb.get_keepbits(info_per_bit, inflevel=0.99999999) - Size: 28B + Dimensions: (inflevel: 1) Coordinates: - dim >> xb.get_keepbits(info_per_bit, inflevel=1.0) - Size: 28B + Dimensions: (inflevel: 1) Coordinates: - dim >> info_per_bit = xb.get_bitinformation(ds) >>> xb.get_keepbits(info_per_bit) - Size: 80B + Dimensions: (dim: 3, inflevel: 1) Coordinates: - * dim (dim) 1.0).any(): raise ValueError("Please provide `inflevel` from interval [0.,1.]") - for bitdim in [ - "bitfloat16", - "bitfloat32", - "bitfloat64", - "bitint16", - "bitint32", - "bitint64", - "bituint16", - "bituint32", - "bituint64", - ]: + for bitdim in ["bit16", "bit32", "bit64"]: # get only variables of bitdim bit_vars = [v for v in info_per_bit.data_vars if bitdim in info_per_bit[v].dims] if bit_vars != []: - cdf = _cdf_from_info_per_bit(info_per_bit[bit_vars], bitdim) - data_type = np.dtype(bitdim.replace("bit", "")) - n_bits, _, _, n_mant = bit_partitioning(data_type) - bitdim_non_mantissa_bits = n_bits - n_mant + if information_filter == "Gradient": + cdf = get_cdf_without_artificial_information( + info_per_bit[bit_vars], + bitdim, + kwargs["threshold"], + kwargs["tolerance"], + bit_vars, + ) + else: + cdf = _cdf_from_info_per_bit(info_per_bit[bit_vars], bitdim) + bitdim_non_mantissa_bits = NMBITS[int(bitdim[3:])] keepmantissabits_bitdim = ( (cdf > inflevel).argmax(bitdim) + 1 - bitdim_non_mantissa_bits ) # keep all mantissa bits for 100% information if 1.0 in inflevel: - bitdim_all_mantissa_bits = n_bits - bitdim_non_mantissa_bits + bitdim_all_mantissa_bits = int(bitdim[3:]) - bitdim_non_mantissa_bits keepall = xr.ones_like(keepmantissabits_bitdim.sel(inflevel=1.0)) * ( bitdim_all_mantissa_bits ) @@ -701,7 +815,7 @@ class JsonCustomEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, (np.ndarray, np.number)): return obj.tolist() - elif isinstance(obj, complex): + elif isinstance(obj, (complex, np.complex)): return [obj.real, obj.imag] elif isinstance(obj, set): return list(obj)