Test temporary directory cleanup (#261)
* Proper tempfile cleanup in tests

* Style fixes

* Small fix

* Final fix in arrow tests
mariosasko authored Jan 4, 2021
1 parent 8062d3a commit 42e439b
Showing 14 changed files with 418 additions and 469 deletions.
4 changes: 4 additions & 0 deletions make.bat
@@ -12,6 +12,8 @@ if "%1"!="" goto error
black --check --line-length 90 --target-version py36 podium tests examples
echo isort --check-only podium tests examples
isort --check-only podium tests examples
echo docformatter podium tests examples --check --recursive --wrap-descriptions 80 --wrap-summaries 80 --pre-summary-newline --make-summary-multi-line
docformatter podium tests examples --check --recursive --wrap-descriptions 80 --wrap-summaries 80 --pre-summary-newline --make-summary-multi-line
echo flake8 podium tests examples
flake8 podium tests examples
goto :EOF
@@ -22,6 +24,8 @@ if "%1"!="" goto error
black --line-length 90 --target-version py36 podium tests examples
echo isort podium tests examples
isort podium tests examples
echo docformatter podium tests examples -i --recursive --wrap-descriptions 80 --wrap-summaries 80 --pre-summary-newline --make-summary-multi-line
docformatter podium tests examples -i --recursive --wrap-descriptions 80 --wrap-summaries 80 --pre-summary-newline --make-summary-multi-line
goto :EOF

:test
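
The two lines added to each target wire docformatter into the existing check and format steps. For reference, this is roughly what those flags do to a one-line docstring; the function below is a hypothetical example, not code from this commit:

# Before: a single-line summary crammed into the opening quotes.
def from_examples(fields, examples):
    """Creates an ArrowDataset from the provided Examples."""

# After docformatter with --pre-summary-newline and --make-summary-multi-line:
# the summary moves to its own line between the quote pairs, wrapped at 80
# columns per --wrap-summaries/--wrap-descriptions. This matches the docstring
# style visible in the arrow_tabular_dataset.py hunk below.
def from_examples(fields, examples):
    """
    Creates an ArrowDataset from the provided Examples.
    """
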
6 changes: 4 additions & 2 deletions podium/arrow/arrow_tabular_dataset.py
@@ -126,7 +126,7 @@ def from_examples(
examples: Iterable[Example],
cache_path: str = None,
data_types: Dict[str, Tuple[pa.DataType, pa.DataType]] = None,
chunk_size=10_000,
chunk_size=1024,
) -> "ArrowDataset":
"""
Creates an ArrowDataset from the provided Examples.
@@ -705,8 +705,9 @@ def close(self):
if self.mmapped_file is not None:
self.mmapped_file.close()
self.mmapped_file = None
self.table = None

else: # Do nothing
else:
warnings.warn("Attempted closing an already closed ArrowDataset.")

def delete_cache(self):
@@ -715,4 +716,5 @@ def delete_cache(self):
"""
if self.mmapped_file is not None:
self.close()

shutil.rmtree(self.cache_path)
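
The close() hunk above makes a second close() call warn instead of silently doing nothing, and also drops the table reference so the memory-mapped data can actually be released. A minimal sketch of the same close-once pattern, using a hypothetical Resource class rather than the real ArrowDataset:

import warnings

class Resource:
    def __init__(self, path):
        self._file = open(path, "rb")

    def close(self):
        if self._file is not None:
            self._file.close()
            self._file = None  # mark as closed so a repeat call is detectable
        else:
            warnings.warn("Attempted closing an already closed Resource.")
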
11 changes: 6 additions & 5 deletions podium/storage/resources/large_resource.py
@@ -122,11 +122,12 @@ def _download_unarchive(self):
Method downloads the resource and decompresses it to the resource location.
"""
os.makedirs(name=self.resource_location)
download_dir = os.path.join(
tempfile.mkdtemp(), self.config[LargeResource.RESOURCE_NAME]
)
self._download(download_destination=download_dir)
self._unarchive(archive_file=download_dir)
with tempfile.TemporaryDirectory() as temp_dir:
download_dir = os.path.join(
temp_dir, self.config[LargeResource.RESOURCE_NAME]
)
self._download(download_destination=download_dir)
self._unarchive(archive_file=download_dir)

def _check_args(self, arguments):
"""
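
This hunk shows the core fix of the commit: tempfile.mkdtemp() returns a directory the caller must delete, while tempfile.TemporaryDirectory() removes it automatically when the with block exits, even if an exception is raised inside. A minimal sketch of the difference, with a stand-in payload instead of the real download:

import os
import shutil
import tempfile

# mkdtemp: the caller owns the directory and must remove it explicitly.
leaked = tempfile.mkdtemp()
shutil.rmtree(leaked)  # easy to forget, which is the leak this commit fixes

# TemporaryDirectory: cleaned up on context exit, no manual rmtree needed.
with tempfile.TemporaryDirectory() as temp_dir:
    target = os.path.join(temp_dir, "resource.zip")
    with open(target, "wb") as f:
        f.write(b"...")  # stand-in for the real download/unarchive
assert not os.path.exists(temp_dir)  # gone once the block exits
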
58 changes: 35 additions & 23 deletions tests/arrow/test_pyarrow_tabular_dataset.py
@@ -51,12 +51,14 @@ def pyarrow_dataset(data, fields):

@pytest.fixture(name="pyarrow_dataset")
def pyarrow_dataset_fixture(data, fields):
return pyarrow_dataset(data, fields)
ad = pyarrow_dataset(data, fields)
yield ad
ad.delete_cache()
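
The fixture now yields the dataset instead of returning it: pytest runs everything after the yield as teardown once the test finishes, pass or fail, so every test using the fixture gets its cache deleted. A minimal sketch of the pattern, with a hypothetical helper name:

import pytest

@pytest.fixture
def dataset():
    ds = build_dataset()  # setup; build_dataset is a hypothetical helper
    yield ds              # the test body runs at this point
    ds.delete_cache()     # teardown: runs even if the test raised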


def test_from_examples(data, fields):
example_factory = ExampleFactory(fields)
examples = map(example_factory.from_list, iter(data))
examples = [example_factory.from_list(ex) for ex in data]
ad = ArrowDataset.from_examples(fields, examples)

for (raw, tokenized), (num, _) in zip(ad.number, data):
@@ -67,6 +69,8 @@ def test_from_examples(data, fields):
assert raw == tok
assert tokenized == tok.split(" ")

ad.delete_cache()


@pytest.mark.parametrize(
"index",
@@ -96,7 +100,7 @@ def test_slicing(index, data, fields, pyarrow_dataset):
assert tok_raw == tok


def test_dump_and_load(pyarrow_dataset):
def test_dump_and_load(pyarrow_dataset, tmpdir):
cache_dir = pyarrow_dataset.dump_cache(cache_path=None)
loaded_dataset = ArrowDataset.load_cache(cache_dir)

@@ -109,6 +113,8 @@ def test_dump_and_load(pyarrow_dataset):
== loaded_dataset.field_dict["tokens"].vocab.stoi
)

loaded_dataset.delete_cache()

dataset_sliced = pyarrow_dataset[8:2:-2]
cache_dir_sliced = dataset_sliced.dump_cache(cache_path=None)
loaded_dataset_sliced = ArrowDataset.load_cache(cache_dir_sliced)
@@ -122,9 +128,10 @@ def test_dump_and_load(pyarrow_dataset):
== loaded_dataset_sliced.field_dict["tokens"].vocab.stoi
)

cache_dir = tempfile.mkdtemp()
pyarrow_dataset.dump_cache(cache_path=cache_dir)
loaded_dataset = ArrowDataset.load_cache(cache_dir)
loaded_dataset_sliced.delete_cache()

pyarrow_dataset.dump_cache(cache_path=tmpdir)
loaded_dataset = ArrowDataset.load_cache(tmpdir)

assert len(loaded_dataset) == len(pyarrow_dataset)
for ex_original, ex_loaded in zip(pyarrow_dataset, loaded_dataset):
@@ -135,6 +142,8 @@ def test_dump_and_load(pyarrow_dataset):
== loaded_dataset.field_dict["tokens"].vocab.stoi
)

loaded_dataset.delete_cache()


def test_finalize_fields(data, fields, mocker):
for field in fields:
@@ -158,6 +167,8 @@ def test_finalize_fields(data, fields, mocker):
# all fields should be finalized
assert f.finalized

dataset.delete_cache()


def test_filtered(data, pyarrow_dataset):
def filter_even(ex):
@@ -203,25 +214,28 @@ def test_from_dataset(data, fields):
assert ds_ex.number == arrow_ex.number
assert ds_ex.tokens == arrow_ex.tokens

pyarrow_dataset.delete_cache()


@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="the reason for failure on Windows is not known at the moment",
)
def test_from_tabular(data, fields):
with tempfile.TemporaryDirectory() as tdir:
test_file = os.path.join(tdir, "test.csv")
with open(test_file, "w") as f:
writer = csv.writer(f)
writer.writerows(data)
def test_from_tabular(data, fields, tmpdir):
test_file = os.path.join(tmpdir, "test.csv")
with open(test_file, "w") as f:
writer = csv.writer(f)
writer.writerows(data)

csv_dataset = ArrowDataset.from_tabular_file(test_file, "csv", fields)
for ex, d in zip(csv_dataset, data):
assert int(ex.number[0]) == d[0]
assert ex.tokens[0] == d[1]
csv_dataset = ArrowDataset.from_tabular_file(test_file, "csv", fields)
for ex, d in zip(csv_dataset, data):
assert int(ex.number[0]) == d[0]
assert ex.tokens[0] == d[1]

csv_dataset.delete_cache()

def test_missing_datatype_exception(data, fields):

def test_missing_datatype_exception(data, fields, tmpdir):
data_null = [(*d, None) for d in data]
null_field = Field(
"null_field", keep_raw=True, allow_missing_data=True, numericalizer=Vocab()
@@ -232,7 +246,7 @@ def test_missing_datatype_exception(data, fields):
examples = map(exf.from_list, data_null)

with pytest.raises(RuntimeError):
ArrowDataset.from_examples(fields_null, examples)
ArrowDataset.from_examples(fields_null, examples, cache_path=tmpdir)


def test_datatype_definition(data, fields):
Expand All @@ -252,16 +266,14 @@ def test_datatype_definition(data, fields):
assert int(ex["number"][0]) == d[0]
assert ex["tokens"][0] == d[1]

dataset.delete_cache()


@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="shutil.rmtree has issues with the recursive removal on Windows",
)
def test_delete_cache(data, fields):
cache_dir = tempfile.mkdtemp()

example_factory = ExampleFactory(fields)
examples = map(example_factory.from_list, iter(data))
examples = map(example_factory.from_list, data)
ad = ArrowDataset.from_examples(fields, examples, cache_path=cache_dir)

assert os.path.exists(cache_dir)
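
Several tests above replace hand-rolled tempfile handling with pytest's built-in tmpdir fixture: pytest hands each test a fresh directory under a managed base path and prunes old runs on its own, so no cleanup code is needed in the test body. A minimal sketch, not from this commit:

import os

def test_writes_a_file(tmpdir):  # pytest injects a unique per-test directory
    path = os.path.join(tmpdir, "out.csv")
    with open(path, "w") as f:
        f.write("1,one\n")
    assert os.path.exists(path)
    # no teardown needed: pytest manages and prunes the base temp directory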