Skip to content

Commit

Permalink
Changes to use doctr on AWS Lambda (mindee#1017)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtvch authored Sep 1, 2022
1 parent a3513b0 commit ea19161
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ Supported datasets
using_doctr/using_datasets
using_doctr/sharing_models
using_doctr/using_model_export
using_doctr/running_on_aws


.. toctree::
Expand Down
7 changes: 7 additions & 0 deletions docs/source/using_doctr/running_on_aws.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
AWS Lambda
========================

AWS Lambda's security policy (read more about Lambda at https://aws.amazon.com/lambda/) does not allow you to write anywhere outside the `/tmp` directory.
There are two things you need to do to make `doctr` work on Lambda:
1. Disable usage of the `multiprocessing` package by setting the `DOCTR_MULTIPROCESSING_DISABLE` environment variable to `TRUE`. You need to do this because this package uses the `/dev/shm` directory for shared memory.
2. Change the directory `doctr` uses for caching models. By default it's `~/.cache/doctr`, which is outside of `/tmp` on AWS Lambda. You can do this by setting the `DOCTR_CACHE_DIR` environment variable.
6 changes: 5 additions & 1 deletion doctr/datasets/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,12 @@ def __init__(
cache_subdir: Optional[str] = None,
**kwargs: Any,
) -> None:
cache_dir = (
str(os.environ.get("DOCTR_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "doctr")))
if cache_dir is None
else cache_dir
)

cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "doctr") if cache_dir is None else cache_dir
cache_subdir = "datasets" if cache_subdir is None else cache_subdir

file_name = file_name if isinstance(file_name, str) else os.path.basename(url)
Expand Down
25 changes: 21 additions & 4 deletions doctr/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,19 @@ def download_from_url(
Returns:
the location of the downloaded file
Note:
You can change cache directory location by using `DOCTR_CACHE_DIR` environment variable.
"""

if not isinstance(file_name, str):
file_name = url.rpartition("/")[-1].split("&")[0]

if not isinstance(cache_dir, str):
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "doctr")
cache_dir = (
str(os.environ.get("DOCTR_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "doctr")))
if cache_dir is None
else cache_dir
)

# Check hash in file name
if hash_prefix is None:
Expand All @@ -84,8 +90,19 @@ def download_from_url(
logging.info(f"Using downloaded & verified file: {file_path}")
return file_path

# Create folder hierarchy
folder_path.mkdir(parents=True, exist_ok=True)
try:
# Create folder hierarchy
folder_path.mkdir(parents=True, exist_ok=True)
except OSError:
error_message = f"Failed creating cache direcotry at {folder_path}"
if os.environ.get("DOCTR_CACHE_DIR", ""):
error_message += " using path from 'DOCTR_CACHE_DIR' environment variable."
else:
error_message += (
". You can change default cache directory using 'DOCTR_CACHE_DIR' environment variable if needed."
)
logging.error(error_message)
raise
# Download the file
try:
print(f"Downloading {url} to {file_path}")
Expand Down
10 changes: 9 additions & 1 deletion doctr/utils/multithreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@


import multiprocessing as mp
import os
from multiprocessing.pool import ThreadPool
from typing import Any, Callable, Iterable, Iterator, Optional

from doctr.file_utils import ENV_VARS_TRUE_VALUES

__all__ = ["multithread_exec"]


Expand All @@ -25,11 +28,16 @@ def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Op
Returns:
iterator of the function's results using the iterable as inputs
Notes:
This function uses ThreadPool from multiprocessing package, which uses `/dev/shm` directory for shared memory.
If you do not have write permissions for this directory (if you run `doctr` on AWS Lambda for instance),
you might want to disable multiprocessing. To achieve that, set 'DOCTR_MULTIPROCESSING_DISABLE' to 'TRUE'.
"""

threads = threads if isinstance(threads, int) else min(16, mp.cpu_count())
# Single-thread
if threads < 2:
if threads < 2 or os.environ.get("DOCTR_MULTIPROCESSING_DISABLE", "").upper() in ENV_VARS_TRUE_VALUES:
results = map(func, seq)
# Multi-threading
else:
Expand Down
46 changes: 46 additions & 0 deletions tests/common/test_utils_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
from pathlib import PosixPath
from unittest.mock import patch

import pytest

from doctr.utils.data import download_from_url


@patch("doctr.utils.data._urlretrieve")
@patch("pathlib.Path.mkdir")
@patch.dict(os.environ, {"HOME": "/"}, clear=True)
def test_download_from_url(mkdir_mock, urlretrieve_mock):
    """Without a cache override, the download targets ``$HOME/.cache/doctr``.

    The environment is cleared except for ``HOME="/"``, and both directory
    creation and the actual retrieval are mocked, so nothing touches the
    filesystem or the network.
    """
    source_url = "test_url"
    # The file name is derived from the URL and placed in the default cache.
    expected_target = PosixPath("/.cache/doctr") / source_url

    download_from_url(source_url)

    urlretrieve_mock.assert_called_with(source_url, expected_target)


@patch.dict(os.environ, {"DOCTR_CACHE_DIR": "/test"}, clear=True)
@patch("doctr.utils.data._urlretrieve")
@patch("pathlib.Path.mkdir")
def test_download_from_url_customizing_cache_dir(mkdir_mock, urlretrieve_mock):
    """``DOCTR_CACHE_DIR`` redirects the download target directory.

    Only ``DOCTR_CACHE_DIR`` is present in the (otherwise cleared)
    environment; directory creation and retrieval are mocked.
    """
    source_url = "test_url"
    # The download must land directly inside the directory named by the variable.
    expected_target = PosixPath("/test") / source_url

    download_from_url(source_url)

    urlretrieve_mock.assert_called_with(source_url, expected_target)


@patch.dict(os.environ, {"HOME": "/"}, clear=True)
@patch("pathlib.Path.mkdir", side_effect=OSError)
@patch("logging.error")
def test_download_from_url_error_creating_directory(logging_mock, mkdir_mock):
    """A failing ``mkdir`` re-raises ``OSError`` and logs a hint about ``DOCTR_CACHE_DIR``.

    ``DOCTR_CACHE_DIR`` is absent from the environment here, so the logged
    message should suggest setting it.
    """
    expected_log = (
        "Failed creating cache direcotry at /.cache/doctr."
        " You can change default cache directory using 'DOCTR_CACHE_DIR' environment variable if needed."
    )

    with pytest.raises(OSError):
        download_from_url("test_url")

    # Checked after the context manager: nothing following the raising call
    # inside the ``with`` body would ever execute.
    logging_mock.assert_called_with(expected_log)


@patch.dict(os.environ, {"HOME": "/", "DOCTR_CACHE_DIR": "/test"}, clear=True)
@patch("pathlib.Path.mkdir", side_effect=OSError)
@patch("logging.error")
def test_download_from_url_error_creating_directory_with_env_var(logging_mock, mkdir_mock):
    """A failing ``mkdir`` with ``DOCTR_CACHE_DIR`` set logs that the path came from the variable.

    The ``OSError`` raised by the mocked ``mkdir`` must still propagate to the caller.
    """
    expected_log = "Failed creating cache direcotry at /test using path from 'DOCTR_CACHE_DIR' environment variable."

    with pytest.raises(OSError):
        download_from_url("test_url")

    # Checked after the context manager: nothing following the raising call
    # inside the ``with`` body would ever execute.
    logging_mock.assert_called_with(expected_log)
11 changes: 11 additions & 0 deletions tests/common/test_utils_multithreading.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import os
from multiprocessing.pool import ThreadPool
from unittest.mock import patch

import pytest

from doctr.utils.multithreading import multithread_exec
Expand All @@ -18,3 +22,10 @@
def test_multithread_exec(input_seq, func, output_seq):
assert list(multithread_exec(func, input_seq)) == output_seq
assert list(multithread_exec(func, input_seq, 0)) == output_seq


@patch.dict(os.environ, {"DOCTR_MULTIPROCESSING_DISABLE": "TRUE"}, clear=True)
def test_multithread_exec_multiprocessing_disable():
    """With ``DOCTR_MULTIPROCESSING_DISABLE=TRUE``, ``ThreadPool.map`` is never invoked."""

    def identity(value):
        return value

    with patch.object(ThreadPool, "map") as mock_tp_map:
        multithread_exec(identity, [1, 2])
        mock_tp_map.assert_not_called()

0 comments on commit ea19161

Please sign in to comment.