diff --git a/detecto/core.py b/detecto/core.py index 7c0b081..cc599a5 100644 --- a/detecto/core.py +++ b/detecto/core.py @@ -594,7 +594,7 @@ def save(self, file): torch.save(self._model.state_dict(), file) @staticmethod - def load(file, classes): + def load(file, classes, model_name=DEFAULT): """Loads a model from a .pth file containing the model weights. :param file: The path to the .pth file containing the saved model. @@ -613,7 +613,7 @@ def load(file, classes): >>> model = Model.load('model_weights.pth', ['ant', 'bee']) """ - model = Model(classes) + model = Model(classes, model_name=model_name) model._model.load_state_dict(torch.load(file, map_location=model._device)) return model