diff --git a/flaml/automl/model.py b/flaml/automl/model.py index 96ada5e158..ad939baad6 100644 --- a/flaml/automl/model.py +++ b/flaml/automl/model.py @@ -912,18 +912,24 @@ def search_space(cls, data_size, task, **params): If OOM, user should change the search space themselves """ - search_space_dict["model_path"] = { - "domain": tune.choice( - [ - "google/electra-base-discriminator", - "bert-base-uncased", - "roberta-base", - "facebook/muppet-roberta-base", - "google/electra-small-discriminator", - ] - ), - "init_value": "facebook/muppet-roberta-base", - } + if task not in NLG_TASKS: + search_space_dict["model_path"] = { + "domain": tune.choice( + [ + "google/electra-base-discriminator", + "bert-base-uncased", + "roberta-base", + "facebook/muppet-roberta-base", + "google/electra-small-discriminator", + ] + ), + "init_value": "facebook/muppet-roberta-base", + } + else: + search_space_dict["model_path"] = { + "domain": tune.choice(["t5-small", "facebook/bart-base"]), + "init_value": "t5-small", + } return search_space_dict diff --git a/flaml/model.py b/flaml/model.py index b780a67d16..610d6942bc 100644 --- a/flaml/model.py +++ b/flaml/model.py @@ -1,9 +1,9 @@ -import warnings - -from flaml.automl.model import * - - -warnings.warn( - "Importing from `flaml.model` is deprecated. Please use `flaml.automl.model`.", - DeprecationWarning, -) +import warnings + +from flaml.automl.model import * + + +warnings.warn( + "Importing from `flaml.model` is deprecated. Please use `flaml.automl.model`.", + DeprecationWarning, +) diff --git a/test/nlp/test_autohf_modelselection.py b/test/nlp/test_autohf_modelselection.py new file mode 100644 index 0000000000..a08267c9fd --- /dev/null +++ b/test/nlp/test_autohf_modelselection.py @@ -0,0 +1,46 @@ +import sys +import pytest +import requests +from utils import get_toy_data_summarization, get_automl_settings +import os +import shutil + + +@pytest.mark.skipif( + sys.platform == "darwin" or sys.version < "3.7", + reason="do not run on mac os or py<3.7", +) +def test_hf_ms(): + from flaml import AutoML + + X_train, y_train, X_val, y_val, X_test = get_toy_data_summarization() + + automl = AutoML() + + automl_settings = { + "gpu_per_trial": 0, + "max_iter": 3, + "time_budget": 20, + "task": "summarization", + "metric": "rouge1", + "log_file_name": "seqclass.log", + "use_ray": False, + "estimator_list": ["transformer_ms"], + } + + try: + automl.fit( + X_train=X_train, + y_train=y_train, + X_val=X_val, + y_val=y_val, + **automl_settings + ) + automl.score(X_val, y_val, **{"metric": "accuracy"}) + automl.pickle("automl.pkl") + except requests.exceptions.HTTPError: + return + + +if __name__ == "__main__": + test_hf_ms()