@@ -217,19 +217,42 @@ class TensorstoreWrapper(DataWrapper["ts.TensorStore"]):
217
217
218
218
def __init__ (self , data : Any ) -> None :
219
219
super ().__init__ (data )
220
+ import json
221
+
220
222
import tensorstore as ts
221
223
222
224
self ._ts = ts
223
225
226
+ spec = self .data .spec ().to_json ()
227
+ labels : Sequence [Hashable ] | None = None
228
+ self ._ts = ts
229
+ if (tform := spec .get ("transform" )) and ("input_labels" in tform ):
230
+ labels = [str (x ) for x in tform ["input_labels" ]]
231
+ elif (
232
+ str (spec .get ("driver" )).startswith ("zarr" )
233
+ and (zattrs := self .data .kvstore .read (".zattrs" ).result ().value )
234
+ and isinstance ((zattr_dict := json .loads (zattrs )), dict )
235
+ and "_ARRAY_DIMENSIONS" in zattr_dict
236
+ ):
237
+ labels = zattr_dict ["_ARRAY_DIMENSIONS" ]
238
+
239
+ if isinstance (labels , Sequence ) and len (labels ) == len (self ._data .domain ):
240
+ self ._labels : list [Hashable ] = [str (x ) for x in labels ]
241
+ self ._data = self .data [ts .d [:].label [self ._labels ]]
242
+ else :
243
+ self ._labels = list (range (len (self ._data .domain )))
244
+
224
245
def sizes (self ) -> Mapping [Hashable , int ]:
225
- return { dim . label : dim . size for dim in self ._data .domain }
246
+ return dict ( zip ( self . _labels , self ._data .domain . shape ))
226
247
227
248
def isel (self , indexers : Indices ) -> np .ndarray :
228
- result = (
229
- self ._data [self ._ts .d [tuple (indexers )][tuple (indexers .values ())]]
230
- .read ()
231
- .result ()
232
- )
249
+ if not indexers :
250
+ slc = slice (None )
251
+ else :
252
+ labels , values = zip (* indexers .items ())
253
+ origins = (0 ,) * len (labels )
254
+ slc = self ._ts .d [labels ].translate_to [origins ][values ]
255
+ result = self ._data [slc ].read ().result ()
233
256
return np .asarray (result )
234
257
235
258
@classmethod
0 commit comments