Add Zarr compatibility functions #478

Open
wants to merge 14 commits into base: main
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -23,4 +23,5 @@ jobs:
- name: Test with pytest
shell: bash -l {0}
run: |
export ZARR_V3_EXPERIMENTAL_API=1
pytest -v --cov
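The exported variable gates zarr-python 2.x's experimental v3 store API, which the zarr_version=3 code paths below rely on. A minimal sketch (not part of this diff) of checking the same flag from Python:

import os

def v3_api_enabled() -> bool:
    # zarr-python 2.x only exposes its v3 stores when this variable is set
    return os.environ.get("ZARR_V3_EXPERIMENTAL_API") == "1"

if not v3_api_enabled():
    raise RuntimeError("set ZARR_V3_EXPERIMENTAL_API=1 before using zarr v3 code paths")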
11 changes: 6 additions & 5 deletions kerchunk/fits.py
@@ -3,12 +3,10 @@
import numcodecs
import numcodecs.abc
import numpy as np
import zarr

from fsspec.implementations.reference import LazyReferenceMapper


from kerchunk.utils import class_factory
from kerchunk.utils import class_factory, _zarr_open
from kerchunk.codecs import AsciiTableCodec, VarArrCodec

try:
@@ -40,6 +38,7 @@ def process_file(
inline_threshold=100,
primary_attr_to_group=False,
out=None,
zarr_version=None,
):
"""
Create JSON references for a single FITS file as a zarr group
@@ -62,7 +61,9 @@
This allows you to supply an fsspec.implementations.reference.LazyReferenceMapper
to write out parquet as the references get filled, or some other dictionary-like class
to customise how references get stored

zarr_version: int
The zarr spec version to target (currently 2 or 3). If None, the zarr
library's default version is used.

Returns
-------
@@ -72,7 +73,7 @@

storage_options = storage_options or {}
out = out or {}
g = zarr.open(out)
g = _zarr_open(out, zarr_version=zarr_version)

with fsspec.open(url, mode="rb", **storage_options) as f:
infile = fits.open(f, do_not_scale_image_data=True)
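The kerchunk/utils.py hunk defining _zarr_open is not shown above; a minimal sketch of what such a helper could look like, assuming zarr.open accepts the zarr_version keyword (zarr-python >= 2.13):

import zarr

def _zarr_open(store, mode="a", zarr_version=None):
    # Hypothetical sketch: only pass zarr_version through when the caller
    # requested a specific spec version, otherwise keep zarr's default.
    if zarr_version is None:
        return zarr.open(store, mode=mode)
    return zarr.open(store, mode=mode, zarr_version=zarr_version)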
11 changes: 7 additions & 4 deletions kerchunk/grib2.py
@@ -20,11 +20,10 @@
)

import fsspec
import zarr
import xarray
import numpy as np

from kerchunk.utils import class_factory, _encode_for_JSON
from kerchunk.utils import class_factory, _encode_for_JSON, _zarr_init_group_and_store
from kerchunk.codecs import GRIBCodec
from kerchunk.combine import MultiZarrToZarr, drop

@@ -113,6 +112,7 @@ def scan_grib(
inline_threshold=100,
skip=0,
filter={},
zarr_version=None,
):
"""
Generate references for a GRIB2 file
@@ -134,6 +134,9 @@
the exact value or is in the given set, are processed.
E.g., the cf-style filter ``{'typeOfLevel': 'heightAboveGround', 'level': 2}``
only keeps messages where heightAboveGround==2.
zarr_version: int
The zarr spec version to target (currently 2 or 3). If None, the zarr
library's default version is used.

Returns
-------
@@ -192,7 +195,7 @@ def scan_grib(
if good is False:
continue

z = zarr.open_group(store)
z, store = _zarr_init_group_and_store(store, zarr_version=zarr_version)
global_attrs = {
f"GRIB_{k}": m[k]
for k in cfgrib.dataset.GLOBAL_ATTRIBUTES_KEYS
@@ -399,7 +402,7 @@ def grib_tree(

# TODO allow passing a LazyReferenceMapper as output?
zarr_store = {}
zroot = zarr.open_group(store=zarr_store)
zroot, zarr_store = _zarr_init_group_and_store(zarr_store, overwrite=False)

aggregations: Dict[str, List] = defaultdict(list)
aggregation_dims: Dict[str, Set] = defaultdict(set)
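Likewise, _zarr_init_group_and_store is defined in kerchunk/utils.py outside this hunk; a hedged sketch of its likely shape, returning both the root group and the store so callers can keep writing reference entries into the same mapping:

import zarr

def _zarr_init_group_and_store(store=None, overwrite=True, zarr_version=None):
    # Hypothetical sketch: reuse (or create) a dict-like store and open the
    # root group against it, targeting the requested spec version if any.
    store = {} if store is None else store
    kwargs = {} if zarr_version is None else {"zarr_version": zarr_version}
    group = zarr.group(store=store, overwrite=overwrite, **kwargs)
    return group, store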
21 changes: 8 additions & 13 deletions kerchunk/hdf.py
@@ -10,7 +10,7 @@
import numcodecs

from .codecs import FillStringsCodec
from .utils import _encode_for_JSON
from .utils import _encode_for_JSON, encode_fill_value, _zarr_init_group_and_store

try:
import h5py
@@ -21,12 +21,6 @@
"for more details."
)

try:
from zarr.meta import encode_fill_value
except ModuleNotFoundError:
# https://github.com/zarr-developers/zarr-python/issues/2021
from zarr.v2.meta import encode_fill_value

lggr = logging.getLogger("h5-to-zarr")
_HIDDEN_ATTRS = { # from h5netcdf.attrs
"REFERENCE_LIST",
@@ -71,10 +65,10 @@ class SingleHdf5ToZarr:
encode: save the ID-to-value mapping in a codec, to produce the real values at read
time; requires this library to be available. Can be efficient storage where there
are few unique values.
out: dict-like or None
out: dict-like, StoreLike, or None
This allows you to supply an fsspec.implementations.reference.LazyReferenceMapper
to write out parquet as the references get filled, or some other dictionary-like class
to customise how references get stored
or a ZarrV3 StoreLike to write out parquet as the references get filled,
or some other dictionary-like class to customise how references get stored
"""

def __init__(
@@ -87,6 +81,7 @@ def __init__(
error="warn",
vlen_encode="embed",
out=None,
zarr_version=None,
):

# Open HDF5 file in read mode...
@@ -111,9 +106,9 @@ def __init__(
if vlen_encode not in ["embed", "null", "leave", "encode"]:
raise NotImplementedError
self.vlen = vlen_encode
self.store = out or {}
self._zroot = zarr.group(store=self.store, overwrite=True)

self._zroot, self.store = _zarr_init_group_and_store(
out or {}, zarr_version=zarr_version or 2
)
self._uri = url
self.error = error
lggr.debug(f"HDF5 file URI: {self._uri}")
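Illustrative usage of the new zarr_version argument on SingleHdf5ToZarr; the file name is a placeholder, not part of the change:

import fsspec
import zarr
from kerchunk.hdf import SingleHdf5ToZarr

with fsspec.open("example.h5", mode="rb") as f:
    # build references targeting the zarr v2 spec (the default applied above)
    refs = SingleHdf5ToZarr(f, url="example.h5", zarr_version=2).translate()

m = fsspec.get_mapper("reference://", fo=refs)
z = zarr.open(m, zarr_version=2)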
59 changes: 45 additions & 14 deletions kerchunk/netCDF3.py
@@ -5,7 +5,11 @@
from fsspec.implementations.reference import LazyReferenceMapper
import fsspec

from kerchunk.utils import _encode_for_JSON, inline_array
from kerchunk.utils import (
_encode_for_JSON,
inline_array,
_zarr_open,
)

try:
from scipy.io._netcdf import ZERO, NC_VARIABLE, netcdf_file, netcdf_variable
@@ -31,6 +35,7 @@ def __init__(
inline_threshold=100,
max_chunk_size=0,
out=None,
zarr_version=None,
**kwargs,
):
"""
@@ -52,6 +57,9 @@
This allows you to supply an fsspec.implementations.reference.LazyReferenceMapper
to write out parquet as the references get filled, or some other dictionary-like class
to customise how references get stored
zarr_version: int
The zarr spec version to target (currently 2 or 3). If None, the zarr
library's default version is used.
args, kwargs: passed to scipy superclass ``scipy.io.netcdf.netcdf_file``
"""
assert kwargs.pop("mmap", False) is False
@@ -63,6 +71,7 @@ def __init__(
self.chunks = {}
self.threshold = inline_threshold
self.max_chunk_size = max_chunk_size
self.zarr_version = zarr_version
self.out = out or {}
self.storage_options = storage_options
self.fp = fsspec.open(filename, **(storage_options or {})).open()
@@ -164,10 +173,9 @@ def translate(self):
Parameters
----------
"""
import zarr

out = self.out
z = zarr.open(out, mode="w")
zroot = _zarr_open(out, mode="w")
for dim, var in self.variables.items():
if dim in self.chunks:
shape = self.chunks[dim][-1]
@@ -191,18 +199,25 @@ def translate(self):
fill = float(fill)
if fill is not None and var.data.dtype.kind == "i":
fill = int(fill)
arr = z.create_dataset(
arr = zroot.create_dataset(
name=dim,
shape=shape,
dtype=var.data.dtype,
fill_value=fill,
chunks=shape,
compression=None,
overwrite=True,
)
part = ".".join(["0"] * len(shape)) or "0"
k = f"{dim}/{part}"
out[k] = [
self.filename,

if self.zarr_version == 3:
part = "/".join(["0"] * len(shape)) or "0"
key = f"data/root/{dim}/c{part}"
else:
part = ".".join(["0"] * len(shape)) or "0"

key = f"{dim}/{part}"

self.out[key] = [self.filename] + [
int(self.chunks[dim][0]),
int(self.chunks[dim][1]),
]
@@ -245,13 +260,14 @@ def translate(self):
fill = float(fill)
if fill is not None and base.kind == "i":
fill = int(fill)
arr = z.create_dataset(
arr = zroot.create_dataset(
name=name,
shape=shape,
dtype=base,
fill_value=fill,
chunks=(1,) + dtype.shape,
compression=None,
overwrite=True,
)
arr.attrs.update(
{
@@ -266,18 +282,33 @@ def translate(self):

arr.attrs["_ARRAY_DIMENSIONS"] = list(var.dimensions)

suffix = (
("." + ".".join("0" for _ in dtype.shape)) if dtype.shape else ""
)
if self.zarr_version == 3:
suffix = (
("/" + "/".join("0" for _ in dtype.shape))
if dtype.shape
else ""
)
else:
suffix = (
("." + ".".join("0" for _ in dtype.shape))
if dtype.shape
else ""
)

for i in range(outer_shape):
out[f"{name}/{i}{suffix}"] = [
if self.zarr_version == 3:
key = f"data/root/{name}/c{i}{suffix}"
else:
key = f"{name}/{i}{suffix}"

self.out[key] = [
self.filename,
int(offset + i * dt.itemsize),
int(dtype.itemsize),
]

offset += dtype.itemsize
z.attrs.update(
zroot.attrs.update(
{
k: v.decode() if isinstance(v, bytes) else str(v)
for k, v in self._attributes.items()
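The two branches above differ only in how chunk keys are spelled; for a 2-D variable named "temp" (the name is illustrative) and chunk index (0, 0) the keys come out roughly as:

name, idx = "temp", (0, 0)
v2_key = f"{name}/" + ".".join(str(i) for i in idx)             # "temp/0.0"
v3_key = f"data/root/{name}/c" + "/".join(str(i) for i in idx)  # "data/root/temp/c0/0"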
31 changes: 18 additions & 13 deletions kerchunk/tests/test_fits.py
@@ -13,12 +13,13 @@
var = os.path.join(testdir, "variable_length_table.fits")


def test_ascii_table():
@pytest.mark.parametrize("zarr_version", [2])
def test_ascii_table(zarr_version):
# this one directly hits a remote server - should cache?
url = "https://fits.gsfc.nasa.gov/samples/WFPC2u5780205r_c0fx.fits"
out = kerchunk.fits.process_file(url, extension=1)
out = kerchunk.fits.process_file(url, extension=1, zarr_version=zarr_version)
m = fsspec.get_mapper("reference://", fo=out, remote_protocol="https")
g = zarr.open(m)
g = zarr.open(m, zarr_version=zarr_version)
arr = g["u5780205r_cvt.c0h.tab"][:]
with fsspec.open(
"https://fits.gsfc.nasa.gov/samples/WFPC2u5780205r_c0fx.fits"
@@ -28,10 +29,11 @@ def test_ascii_table():
assert list(hdu.data.astype(arr.dtype) == arr) == [True, True, True, True]


def test_binary_table():
out = kerchunk.fits.process_file(btable, extension=1)
@pytest.mark.parametrize("zarr_version", [2, 3])
def test_binary_table(zarr_version):
out = kerchunk.fits.process_file(btable, extension=1, zarr_version=zarr_version)
m = fsspec.get_mapper("reference://", fo=out)
z = zarr.open(m)
z = zarr.open(m, zarr_version=zarr_version)
arr = z["1"]
with open(btable, "rb") as f:
hdul = fits.open(f)
@@ -45,38 +47,41 @@ def test_binary_table():
).all()  # strings come out as bytes


def test_cube():
out = kerchunk.fits.process_file(range_im)
@pytest.mark.parametrize("zarr_version", [2])
def test_cube(zarr_version):
out = kerchunk.fits.process_file(range_im, zarr_version=zarr_version)
m = fsspec.get_mapper("reference://", fo=out)
z = zarr.open(m)
z = zarr.open(m, zarr_version=zarr_version)
arr = z["PRIMARY"]
with open(range_im, "rb") as f:
hdul = fits.open(f)
expected = hdul[0].data
assert (arr[:] == expected).all()


def test_with_class():
@pytest.mark.parametrize("zarr_version", [2])
def test_with_class(zarr_version):
ftz = kerchunk.fits.FitsToZarr(range_im)
out = ftz.translate()
assert "fits" in repr(ftz)
m = fsspec.get_mapper("reference://", fo=out)
z = zarr.open(m)
z = zarr.open(m, zarr_version=zarr_version)
arr = z["PRIMARY"]
with open(range_im, "rb") as f:
hdul = fits.open(f)
expected = hdul[0].data
assert (arr[:] == expected).all()


def test_var():
@pytest.mark.parametrize("zarr_version", [2])
def test_var(zarr_version):
data = fits.open(var)[1].data
expected = [_.tolist() for _ in data["var"]]

ftz = kerchunk.fits.FitsToZarr(var)
out = ftz.translate()
m = fsspec.get_mapper("reference://", fo=out)
z = zarr.open(m)
z = zarr.open(m, zarr_version=zarr_version)
arr = z["1"]
vars = [_.tolist() for _ in arr["var"]]

11 changes: 8 additions & 3 deletions kerchunk/tests/test_grib.py
@@ -21,14 +21,19 @@
here = os.path.dirname(__file__)


def test_one():
@pytest.mark.parametrize("zarr_version", [2])
def test_one(zarr_version):
# from https://dd.weather.gc.ca/model_gem_regional/10km/grib2/00/000
fn = os.path.join(here, "CMC_reg_DEPR_ISBL_10_ps10km_2022072000_P000.grib2")
out = scan_grib(fn)
out = scan_grib(fn, zarr_version=zarr_version)
ds = xr.open_dataset(
"reference://",
engine="zarr",
backend_kwargs={"consolidated": False, "storage_options": {"fo": out[0]}},
backend_kwargs={
"consolidated": False,
"zarr_version": zarr_version,
"storage_options": {"fo": out[0]},
},
)

assert ds.attrs["GRIB_centre"] == "cwao"