diff --git a/bindsnet/datasets/collate.py b/bindsnet/datasets/collate.py index 1707973..46518d3 100644 --- a/bindsnet/datasets/collate.py +++ b/bindsnet/datasets/collate.py @@ -8,7 +8,14 @@ """ import torch -from torch._six import container_abcs, string_classes, int_classes + +# https://github.com/pytorch/pytorch/pull/94709#issuecomment-1461471006 +try: + from torch._six import container_abcs, string_classes, int_classes +except ModuleNotFoundError: + int_classes = int + string_classes = str + import collections.abc as container_abcs from torch.utils.data._utils import collate as pytorch_collate diff --git a/bindsnet/pipeline/base_pipeline.py b/bindsnet/pipeline/base_pipeline.py index b1026c3..04c965b 100644 --- a/bindsnet/pipeline/base_pipeline.py +++ b/bindsnet/pipeline/base_pipeline.py @@ -2,7 +2,13 @@ from typing import Tuple, Dict, Any import torch -from torch._six import container_abcs, string_classes + +# https://github.com/pytorch/pytorch/pull/94709#issuecomment-1461471006 +try: + from torch._six import container_abcs, string_classes +except ModuleNotFoundError: + string_classes = str + import collections.abc as container_abcs from ..network import Network from ..network.monitors import Monitor