diff --git a/cellpose/dynamics.py b/cellpose/dynamics.py index 33e8c3c6..83282c1e 100644 --- a/cellpose/dynamics.py +++ b/cellpose/dynamics.py @@ -578,7 +578,22 @@ def remove_bad_flow_masks(masks, flows, threshold=0.4, use_gpu=False, device=Non """ if masks.size > 10000*10000: - if masks.size * 20 > torch.cuda.mem_get_info()[0]: + + major_version, minor_version, _ = torch.__version__.split(".") + + if major_version == "1" and int(minor_version) < 10: + # for PyTorch version lower than 1.10 + def mem_info(): + total_mem = torch.cuda.get_device_properties(0).total_memory + used_mem = torch.cuda.memory_allocated() + return total_mem, used_mem + else: + # for PyTorch version 1.10 and above + def mem_info(): + total_mem, used_mem = torch.cuda.mem_get_info() + return total_mem, used_mem + + if masks.size * 20 > mem_info()[0]: dynamics_logger.warning('WARNING: image is very large, not using gpu to compute flows from masks for QC step flow_threshold') dynamics_logger.info('turn off QC step with flow_threshold=0 if too slow') use_gpu = False