From 44aac66fa386bade84bfac63520a4260433231cd Mon Sep 17 00:00:00 2001 From: Matthias Schaub Date: Tue, 14 Nov 2023 16:05:18 +0100 Subject: [PATCH] feat(ml-models): download models from neptune.ai --- config/sample.config.toml | 4 ++ docs/configuration.md | 4 ++ docs/development-setup.md | 6 +++ sketch_map_tool/config.py | 10 +++- .../upload_processing/ml_models.py | 54 +++++++++++++++++++ tests/unit/test_config.py | 4 ++ tests/unit/test_ml_models.py | 30 +++++++++++ 7 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 sketch_map_tool/upload_processing/ml_models.py create mode 100644 tests/unit/test_ml_models.py diff --git a/config/sample.config.toml b/config/sample.config.toml index 4f71e25a..1e7e03cd 100644 --- a/config/sample.config.toml +++ b/config/sample.config.toml @@ -7,3 +7,7 @@ wms-url = "https://maps.heigit.org/osm-carto/service?SERVICE=WMS&VERSION=1.1.1" wms-layers = "heigit:osm-carto@2xx" wms-read-timeout = 600 max-nr-simultaneous-uploads = 100 +neptune_api_token = "h0dHBzOi8aHR06E0Z...jMifQ" +neptune_project = "HeiGIT/SketchMapTool" +neptune_model_id_yolo = "SMT-OSM-1" +neptune_model_id_sam = "SMT-SAM-1" diff --git a/docs/configuration.md b/docs/configuration.md index bbc01cdf..e5838a63 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -15,3 +15,7 @@ To create a new configuration file simply copy the sample configuration file and ``` cp sample.config.toml config.toml ``` + +## Required Configuration + +Except of the API token (`SMT-NEPTUNE-API-TOKEN`) for neptune.ai all configuration values come with defaults for development purposes. Please make sure to configure the API token for your environment. diff --git a/docs/development-setup.md b/docs/development-setup.md index 88902cde..6f6c060f 100644 --- a/docs/development-setup.md +++ b/docs/development-setup.md @@ -47,6 +47,8 @@ npm run build # build/bundle JS and CSS Please refer to the [configuration documentation](/docs/configuration.md). +> TL;DR: Except of the API token (`SMT-NEPTUNE-API-TOKEN`) for neptune.ai all configuration values come with defaults for development purposes. Please make sure to configure the API token for your environment. + ## Usage ### 1. Start Celery (Task Queue) @@ -104,3 +106,7 @@ If you setup sketch-map-tool in an IDE like PyCharm please make sure that your I Go thought the setup steps above in the terminal and change interpreter settings in the IDE to point to the mamba/conda environment. Also make sure the environment variable `PROJ_LIB` to point to the `proj` directory of the mamba/conda environment. + +## Troubleshooting + +Make sure that Poetry does not try to manage the virtual environment. Check with `poetry env list`. If any environment are listed remove them: `poetry env remove ...` diff --git a/sketch_map_tool/config.py b/sketch_map_tool/config.py index 5f357709..80e2cc0c 100644 --- a/sketch_map_tool/config.py +++ b/sketch_map_tool/config.py @@ -18,7 +18,7 @@ def get_config_path() -> str: return os.getenv("SMT-CONFIG", default=default) -def load_config_default() -> Dict[str, str]: +def load_config_default() -> Dict[str, str | int | float]: return { "data-dir": get_default_data_dir(), "user-agent": "sketch-map-tool", @@ -29,6 +29,10 @@ def load_config_default() -> Dict[str, str]: "wms-read-timeout": 600, "max-nr-simultaneous-uploads": 100, "max_pixel_per_image": 10e8, # 10.000*10.000 + "neptune_project": "HeiGIT/SketchMapTool", + "neptune_api_token": "", + "neptune_model_id_yolo": "SMT-OSM-1", + "neptune_model_id_sam": "SMT-SAM-1", } @@ -53,6 +57,10 @@ def load_config_from_env() -> Dict[str, str]: "wms-read-timeout": os.getenv("SMT-WMS-READ-TIMEOUT"), "max-nr-simultaneous-uploads": os.getenv("SMT-MAX-NR-SIM-UPLOADS"), "max_pixel_per_image": os.getenv("MAX-PIXEL-PER-IMAGE"), + "neptune_project": os.getenv("SMT-NEPTUNE-PROJECT"), + "neptune_api_token": os.getenv("SMT-NEPTUNE-API-TOKEN"), + "neptune_model_id_yolo": os.getenv("SMT-NEPTUNE-MODEL-ID-YOLO"), + "neptune_model_id_sam": os.getenv("SMT-NEPTUNE-MODEL-ID-SAM"), } return {k: v for k, v in cfg.items() if v is not None} diff --git a/sketch_map_tool/upload_processing/ml_models.py b/sketch_map_tool/upload_processing/ml_models.py new file mode 100644 index 00000000..82522a1c --- /dev/null +++ b/sketch_map_tool/upload_processing/ml_models.py @@ -0,0 +1,54 @@ +import logging +from pathlib import Path + +import neptune + +from sketch_map_tool.config import get_config_value + +PROJECT = get_config_value("neptune_project") +API_TOKEN = get_config_value("neptune_api_token") + + +def init_model(id: str) -> Path: + """Initilaze model. Download model to data dir if not present.""" + # TODO: + # _check_id(id) + + data_dir = Path(get_config_value("data-dir")) + model = neptune.init_model_version( + with_id=id, + project=PROJECT, + api_token=API_TOKEN, + mode="read-only", + ) + + raw = data_dir / id + path = raw.with_suffix(_get_file_suffix(id)) + if not path.is_file(): + logging.info(f"Downloading model {id} from neptune.ai to {path}.") + model["model"].download(str(path)) + + # TODO: check if model is valid/working + logging.info("Model available model from neptune.ai: " + id) + return path + + +def _check_id(id: str): + # TODO: + project = neptune.init_project( + project=PROJECT, + api_token=API_TOKEN, + mode="read-only", + ) + + if not project.exists("models/" + id): + raise ValueError("Invalid model ID: " + id) + + +def _get_file_suffix(id: str) -> str: + if "SAM" in id: + return ".pth" + elif "OSM" in id: + return ".pt" + else: + raise ValueError("Unexpected model ID: " + id) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index a45125df..4681fea8 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -28,6 +28,10 @@ def config_keys(): "wms-read-timeout", "max-nr-simultaneous-uploads", "max_pixel_per_image", + "neptune_project", + "neptune_api_token", + "neptune_model_id_yolo", + "neptune_model_id_sam", ) diff --git a/tests/unit/test_ml_models.py b/tests/unit/test_ml_models.py new file mode 100644 index 00000000..387d285a --- /dev/null +++ b/tests/unit/test_ml_models.py @@ -0,0 +1,30 @@ +import pytest +from hypothesis import example, given +from hypothesis.strategies import text + +from sketch_map_tool.config import get_config_value +from sketch_map_tool.upload_processing import ml_models +from tests import vcr_app as vcr + + +@pytest.mark.parametrize( + "id", + ( + get_config_value("neptune_model_id_yolo"), + get_config_value("neptune_model_id_sam"), + ), +) +@pytest.mark.skip("longrunning tests. downloads ml-models from neptunge.ai") +def test_init_model(id, monkeypatch, tmpdir): + monkeypatch.setenv("SMT-DATA-DIR", tmpdir) + path = ml_models.init_model(id) + assert path.is_file() + + +@given(text()) +@example("") +@pytest.mark.skip(reason="not implemented yet") +@vcr.use_cassette +def test_init_model_unexpected_id(id): + with pytest.raises(ValueError): + ml_models.init_model(id)