-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(ml-models): download models from neptune.ai
- Loading branch information
1 parent
6915cd1
commit 44aac66
Showing
7 changed files
with
111 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import logging | ||
from pathlib import Path | ||
|
||
import neptune | ||
|
||
from sketch_map_tool.config import get_config_value | ||
|
||
PROJECT = get_config_value("neptune_project") | ||
API_TOKEN = get_config_value("neptune_api_token") | ||
|
||
|
||
def init_model(id: str) -> Path: | ||
"""Initilaze model. Download model to data dir if not present.""" | ||
# TODO: | ||
# _check_id(id) | ||
|
||
data_dir = Path(get_config_value("data-dir")) | ||
model = neptune.init_model_version( | ||
with_id=id, | ||
project=PROJECT, | ||
api_token=API_TOKEN, | ||
mode="read-only", | ||
) | ||
|
||
raw = data_dir / id | ||
path = raw.with_suffix(_get_file_suffix(id)) | ||
if not path.is_file(): | ||
logging.info(f"Downloading model {id} from neptune.ai to {path}.") | ||
model["model"].download(str(path)) | ||
|
||
# TODO: check if model is valid/working | ||
logging.info("Model available model from neptune.ai: " + id) | ||
return path | ||
|
||
|
||
def _check_id(id: str): | ||
# TODO: | ||
project = neptune.init_project( | ||
project=PROJECT, | ||
api_token=API_TOKEN, | ||
mode="read-only", | ||
) | ||
|
||
if not project.exists("models/" + id): | ||
raise ValueError("Invalid model ID: " + id) | ||
|
||
|
||
def _get_file_suffix(id: str) -> str: | ||
if "SAM" in id: | ||
return ".pth" | ||
elif "OSM" in id: | ||
return ".pt" | ||
else: | ||
raise ValueError("Unexpected model ID: " + id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import pytest | ||
from hypothesis import example, given | ||
from hypothesis.strategies import text | ||
|
||
from sketch_map_tool.config import get_config_value | ||
from sketch_map_tool.upload_processing import ml_models | ||
from tests import vcr_app as vcr | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"id", | ||
( | ||
get_config_value("neptune_model_id_yolo"), | ||
get_config_value("neptune_model_id_sam"), | ||
), | ||
) | ||
@pytest.mark.skip("longrunning tests. downloads ml-models from neptunge.ai") | ||
def test_init_model(id, monkeypatch, tmpdir): | ||
monkeypatch.setenv("SMT-DATA-DIR", tmpdir) | ||
path = ml_models.init_model(id) | ||
assert path.is_file() | ||
|
||
|
||
@given(text()) | ||
@example("") | ||
@pytest.mark.skip(reason="not implemented yet") | ||
@vcr.use_cassette | ||
def test_init_model_unexpected_id(id): | ||
with pytest.raises(ValueError): | ||
ml_models.init_model(id) |