diff --git a/tsensor/analysis.py b/tsensor/analysis.py index abc624e..37c8269 100644 --- a/tsensor/analysis.py +++ b/tsensor/analysis.py @@ -24,7 +24,6 @@ import os import sys import traceback -import torch import inspect import hashlib @@ -456,8 +455,8 @@ def istensor(x): def _shape(v): # do we have a shape and it answers len()? Should get stuff right. - if hasattr(v, "shape") and hasattr(v.shape,"__len__"): - if isinstance(v.shape, torch.Size): + if hasattr(v, "shape") and hasattr(v.shape, "__len__"): + if v.shape.__class__.__module__ == "torch" and v.shape.__class__.__name__ == "Size": if len(v.shape)==0: return None return list(v.shape)