Skip to content

Commit e606d53

Browse files
authored
fix: tensorstore isel (#25)
* fix: tensorstore isel * fix tensorstore
1 parent 388965f commit e606d53

File tree

2 files changed

+36
-19
lines changed

2 files changed

+36
-19
lines changed

examples/tensorstore_arr.py

+7-13
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,15 @@
88

99
import ndv
1010

11-
data = ndv.data.cells3d()
12-
1311
ts_array = ts.open(
1412
{
15-
"driver": "zarr",
16-
"kvstore": {"driver": "memory"},
17-
"transform": {
18-
# tensorstore supports labeled dimensions
19-
"input_labels": ["z", "c", "y", "x"],
13+
"driver": "n5",
14+
"kvstore": {
15+
"driver": "s3",
16+
"bucket": "janelia-cosem-datasets",
17+
"path": "jrc_hela-3/jrc_hela-3.n5/labels/er-mem_pred/s4/",
2018
},
2119
},
22-
create=True,
23-
shape=data.shape,
24-
dtype=data.dtype,
2520
).result()
26-
ts_array[:] = ndv.data.cells3d()
27-
28-
ndv.imshow(ts_array)
21+
ts_array = ts_array[ts.d[:].label["z", "y", "x"]]
22+
ndv.imshow(ts_array[ts.d[("y", "x", "z")].transpose[:]])

src/ndv/viewer/_data_wrapper.py

+29-6
Original file line numberDiff line numberDiff line change
@@ -217,19 +217,42 @@ class TensorstoreWrapper(DataWrapper["ts.TensorStore"]):
217217

218218
def __init__(self, data: Any) -> None:
219219
super().__init__(data)
220+
import json
221+
220222
import tensorstore as ts
221223

222224
self._ts = ts
223225

226+
spec = self.data.spec().to_json()
227+
labels: Sequence[Hashable] | None = None
228+
self._ts = ts
229+
if (tform := spec.get("transform")) and ("input_labels" in tform):
230+
labels = [str(x) for x in tform["input_labels"]]
231+
elif (
232+
str(spec.get("driver")).startswith("zarr")
233+
and (zattrs := self.data.kvstore.read(".zattrs").result().value)
234+
and isinstance((zattr_dict := json.loads(zattrs)), dict)
235+
and "_ARRAY_DIMENSIONS" in zattr_dict
236+
):
237+
labels = zattr_dict["_ARRAY_DIMENSIONS"]
238+
239+
if isinstance(labels, Sequence) and len(labels) == len(self._data.domain):
240+
self._labels: list[Hashable] = [str(x) for x in labels]
241+
self._data = self.data[ts.d[:].label[self._labels]]
242+
else:
243+
self._labels = list(range(len(self._data.domain)))
244+
224245
def sizes(self) -> Mapping[Hashable, int]:
225-
return {dim.label: dim.size for dim in self._data.domain}
246+
return dict(zip(self._labels, self._data.domain.shape))
226247

227248
def isel(self, indexers: Indices) -> np.ndarray:
228-
result = (
229-
self._data[self._ts.d[tuple(indexers)][tuple(indexers.values())]]
230-
.read()
231-
.result()
232-
)
249+
if not indexers:
250+
slc = slice(None)
251+
else:
252+
labels, values = zip(*indexers.items())
253+
origins = (0,) * len(labels)
254+
slc = self._ts.d[labels].translate_to[origins][values]
255+
result = self._data[slc].read().result()
233256
return np.asarray(result)
234257

235258
@classmethod

0 commit comments

Comments
 (0)