diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 6b3ddade8..2a722eec6 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -29,9 +29,9 @@ try: import _codecs import numpy as np - # add safe globals, known to be needed for metaclip weights + # add safe globals that are known to be needed for metaclip weights loading in weights_only=True mode torch.serialization.add_safe_globals([ - _codecs.encode, # now in pytorch main but some pytorch versions w/ weights_only flag don't have it + _codecs.encode, # this one not needed for PyTorch >= 2.5.0 np.core.multiarray.scalar, np.dtype, np.dtypes.Float64DType,