diff --git a/nequip/__init__.py b/nequip/__init__.py index ce145b41..2481bf56 100644 --- a/nequip/__init__.py +++ b/nequip/__init__.py @@ -8,7 +8,7 @@ import warnings # torch version checks -torch_version = packaging.version.parse(torch.__version__) +torch_version = packaging.version.parse(torch.__version__.split("+")[0]) # only allow 1.11*, 1.13* or higher (no 1.12.*) assert (torch_version == packaging.version.parse("1.11")) or (