2121import numpy as np
2222import torch
2323
24- from monai .utils import ensure_tuple_size , get_package_version , optional_import , require_pkg , version_geq
24+ from monai .utils import (
25+ deprecated_arg ,
26+ ensure_tuple_size ,
27+ get_package_version ,
28+ optional_import ,
29+ require_pkg ,
30+ version_geq ,
31+ )
2532
2633if TYPE_CHECKING :
2734 import zarr
@@ -218,15 +225,41 @@ class ZarrAvgMerger(Merger):
218225 store: the zarr store to save the final results. Default is "merged.zarr".
219226 value_store: the zarr store to save the value aggregating tensor. Default is a temporary store.
220227 count_store: the zarr store to save the sample counting tensor. Default is a temporary store.
221- compressor: the compressor for final merged zarr array. Default is "default".
228+ compressor: the compressor for final merged zarr array. Default is None.
229+ Deprecated since 1.5.0 and will be removed in 1.7.0. Use codecs instead.
222230 value_compressor: the compressor for value aggregating zarr array. Default is None.
231+ Deprecated since 1.5.0 and will be removed in 1.7.0. Use value_codecs instead.
223232 count_compressor: the compressor for sample counting zarr array. Default is None.
233+ Deprecated since 1.5.0 and will be removed in 1.7.0. Use count_codecs instead.
234+ codecs: the codecs for final merged zarr array. Default is None.
235+ For zarr v3, this is a list of codec configurations. See zarr documentation for details.
236+ value_codecs: the codecs for value aggregating zarr array. Default is None.
237+ For zarr v3, this is a list of codec configurations. See zarr documentation for details.
238+ count_codecs: the codecs for sample counting zarr array. Default is None.
239+ For zarr v3, this is a list of codec configurations. See zarr documentation for details.
224240 chunks : int or tuple of ints that defines the chunk shape, or boolean. Default is True.
225241 If True, chunk shape will be guessed from `shape` and `dtype`.
226242 If False, it will be set to `shape`, i.e., single chunk for the whole array.
227243 If an int, the chunk size in each dimension will be given by the value of `chunks`.
228244 """
229245
246+ @deprecated_arg (
247+ name = "compressor" , since = "1.5.0" , removed = "1.7.0" , new_name = "codecs" , msg_suffix = "Please use 'codecs' instead."
248+ )
249+ @deprecated_arg (
250+ name = "value_compressor" ,
251+ since = "1.5.0" ,
252+ removed = "1.7.0" ,
253+ new_name = "value_codecs" ,
254+ msg_suffix = "Please use 'value_codecs' instead." ,
255+ )
256+ @deprecated_arg (
257+ name = "count_compressor" ,
258+ since = "1.5.0" ,
259+ removed = "1.7.0" ,
260+ new_name = "count_codecs" ,
261+ msg_suffix = "Please use 'count_codecs' instead." ,
262+ )
230263 def __init__ (
231264 self ,
232265 merged_shape : Sequence [int ],
@@ -240,6 +273,9 @@ def __init__(
240273 compressor : str | None = None ,
241274 value_compressor : str | None = None ,
242275 count_compressor : str | None = None ,
276+ codecs : list | None = None ,
277+ value_codecs : list | None = None ,
278+ count_codecs : list | None = None ,
243279 chunks : Sequence [int ] | bool = True ,
244280 thread_locking : bool = True ,
245281 ) -> None :
@@ -251,7 +287,11 @@ def __init__(
251287 self .count_dtype = count_dtype
252288 self .store = store
253289 self .tmpdir : TemporaryDirectory | None
254- if version_geq (get_package_version ("zarr" ), "3.0.0" ):
290+
291+ # Handle zarr v3 vs older versions
292+ is_zarr_v3 = version_geq (get_package_version ("zarr" ), "3.0.0" )
293+
294+ if is_zarr_v3 :
255295 if value_store is None :
256296 self .tmpdir = TemporaryDirectory ()
257297 self .value_store = zarr .storage .LocalStore (self .tmpdir .name ) # type: ignore
@@ -266,34 +306,119 @@ def __init__(
266306 self .tmpdir = None
267307 self .value_store = zarr .storage .TempStore () if value_store is None else value_store # type: ignore
268308 self .count_store = zarr .storage .TempStore () if count_store is None else count_store # type: ignore
309+
269310 self .chunks = chunks
270- self .compressor = compressor
271- self .value_compressor = value_compressor
272- self .count_compressor = count_compressor
273- self .output = zarr .empty (
274- shape = self .merged_shape ,
275- chunks = self .chunks ,
276- dtype = self .output_dtype ,
277- compressor = self .compressor ,
278- store = self .store ,
279- overwrite = True ,
280- )
281- self .values = zarr .zeros (
282- shape = self .merged_shape ,
283- chunks = self .chunks ,
284- dtype = self .value_dtype ,
285- compressor = self .value_compressor ,
286- store = self .value_store ,
287- overwrite = True ,
288- )
289- self .counts = zarr .zeros (
290- shape = self .merged_shape ,
291- chunks = self .chunks ,
292- dtype = self .count_dtype ,
293- compressor = self .count_compressor ,
294- store = self .count_store ,
295- overwrite = True ,
296- )
311+
312+ # Handle compressor/codecs based on zarr version
313+ is_zarr_v3 = version_geq (get_package_version ("zarr" ), "3.0.0" )
314+
315+ # Initialize codecs/compressor attributes with proper types
316+ self .codecs : list | None = None
317+ self .value_codecs : list | None = None
318+ self .count_codecs : list | None = None
319+
320+ if is_zarr_v3 :
321+ # For zarr v3, use codecs or convert compressor to codecs
322+ if codecs is not None :
323+ self .codecs = codecs
324+ elif compressor is not None :
325+ # Convert compressor to codec format
326+ if isinstance (compressor , (list , tuple )):
327+ self .codecs = compressor
328+ else :
329+ self .codecs = [compressor ]
330+ else :
331+ self .codecs = None
332+
333+ if value_codecs is not None :
334+ self .value_codecs = value_codecs
335+ elif value_compressor is not None :
336+ if isinstance (value_compressor , (list , tuple )):
337+ self .value_codecs = value_compressor
338+ else :
339+ self .value_codecs = [value_compressor ]
340+ else :
341+ self .value_codecs = None
342+
343+ if count_codecs is not None :
344+ self .count_codecs = count_codecs
345+ elif count_compressor is not None :
346+ if isinstance (count_compressor , (list , tuple )):
347+ self .count_codecs = count_compressor
348+ else :
349+ self .count_codecs = [count_compressor ]
350+ else :
351+ self .count_codecs = None
352+ else :
353+ # For zarr v2, use compressors
354+ if codecs is not None :
355+ # If codecs are specified in v2, use the first codec as compressor
356+ self .codecs = codecs [0 ] if isinstance (codecs , (list , tuple )) else codecs
357+ else :
358+ self .codecs = compressor # type: ignore[assignment]
359+
360+ if value_codecs is not None :
361+ self .value_codecs = value_codecs [0 ] if isinstance (value_codecs , (list , tuple )) else value_codecs
362+ else :
363+ self .value_codecs = value_compressor # type: ignore[assignment]
364+
365+ if count_codecs is not None :
366+ self .count_codecs = count_codecs [0 ] if isinstance (count_codecs , (list , tuple )) else count_codecs
367+ else :
368+ self .count_codecs = count_compressor # type: ignore[assignment]
369+
370+ # Create zarr arrays with appropriate parameters based on version
371+ if is_zarr_v3 :
372+ self .output = zarr .empty (
373+ shape = self .merged_shape ,
374+ chunks = self .chunks ,
375+ dtype = self .output_dtype ,
376+ codecs = self .codecs ,
377+ store = self .store ,
378+ overwrite = True ,
379+ )
380+ self .values = zarr .zeros (
381+ shape = self .merged_shape ,
382+ chunks = self .chunks ,
383+ dtype = self .value_dtype ,
384+ codecs = self .value_codecs ,
385+ store = self .value_store ,
386+ overwrite = True ,
387+ )
388+ self .counts = zarr .zeros (
389+ shape = self .merged_shape ,
390+ chunks = self .chunks ,
391+ dtype = self .count_dtype ,
392+ codecs = self .count_codecs ,
393+ store = self .count_store ,
394+ overwrite = True ,
395+ )
396+ else :
397+ self .output = zarr .empty (
398+ shape = self .merged_shape ,
399+ chunks = self .chunks ,
400+ dtype = self .output_dtype ,
401+ compressor = self .codecs ,
402+ store = self .store ,
403+ overwrite = True ,
404+ )
405+ self .values = zarr .zeros (
406+ shape = self .merged_shape ,
407+ chunks = self .chunks ,
408+ dtype = self .value_dtype ,
409+ compressor = self .value_codecs ,
410+ store = self .value_store ,
411+ overwrite = True ,
412+ )
413+ self .counts = zarr .zeros (
414+ shape = self .merged_shape ,
415+ chunks = self .chunks ,
416+ dtype = self .count_dtype ,
417+ compressor = self .count_codecs ,
418+ store = self .count_store ,
419+ overwrite = True ,
420+ )
421+
297422 self .lock : threading .Lock | nullcontext
298423 if thread_locking :
299424 # use lock to protect the in-place addition during aggregation
0 commit comments