diff --git a/python/hsml/util.py b/python/hsml/util.py index c47733d50..6ef6d9053 100644 --- a/python/hsml/util.py +++ b/python/hsml/util.py @@ -232,15 +232,15 @@ def validate_metrics(metrics): def get_predictor_for_model(model, **kwargs): from hsml.model import Model as BaseModel + from hsml.predictor import Predictor as BasePredictor from hsml.python.model import Model as PyModel + from hsml.python.predictor import Predictor as PyPredictor from hsml.sklearn.model import Model as SkLearnModel + from hsml.sklearn.predictor import Predictor as SkLearnPredictor from hsml.tensorflow.model import Model as TFModel + from hsml.tensorflow.predictor import Predictor as TFPredictor from hsml.torch.model import Model as TorchModel from hsml.torch.predictor import Predictor as TorchPredictor - from hsml.predictor import Predictor as BasePredictor - from hsml.tensorflow.predictor import Predictor as TFPredictor - from hsml.sklearn.predictor import Predictor as SkLearnPredictor - from hsml.python.predictor import Predictor as PyPredictor if not isinstance(model, BaseModel): raise ValueError( "model is of type {}, but an instance of {} class is expected".format( @@ -248,15 +248,15 @@ def get_predictor_for_model(model, **kwargs): ) ) - if type(model) == TFModel: + if type(model) is TFModel: return TFPredictor(**kwargs) - if type(model) == TorchModel: + if type(model) is TorchModel: return TorchPredictor(**kwargs) - if type(model) == SkLearnModel: + if type(model) is SkLearnModel: return SkLearnPredictor(**kwargs) - if type(model) == PyModel: + if type(model) is PyModel: return PyPredictor(**kwargs) - if type(model) == BaseModel: + if type(model) is BaseModel: return BasePredictor( # python as default framework and model server model_framework=MODEL.FRAMEWORK_PYTHON, model_server=PREDICTOR.MODEL_SERVER_PYTHON,