
Commit

update logic for when dataset already exists locally. move unzipping into format_dataset
Kory Stiger committed Jul 29, 2021
1 parent 81af1f2 commit 9a5080e
Showing 3 changed files with 53 additions and 43 deletions.
67 changes: 36 additions & 31 deletions test/client.py
@@ -1,6 +1,9 @@
import json

import zpy.client as zpy
import unittest

from zpy.client_util import dict_hash, remove_n_extensions


def test_1(**init_kwargs):
@@ -26,17 +29,6 @@ def test_2(**init_kwargs):
print(json.dumps(urls, indent=4, sort_keys=True))


def test_3():
""""""
zpy.init(
project_uuid="feb6e594-55e0-4f87-9e75-5a128221499f",
auth_token="a4a13763b0dc0017b1fc9af890e9efea58fd072074ab9a169e5dcf0633310f28",
)
dataset_config = zpy.DatasetConfig("dumpster_v5.1")
dataset_config.set("run\.padding_style", "messy")
zpy.generate(dataset_config, num_datapoints=3)


def pretty_print(object):
try:
json.dumps(object)
@@ -64,27 +56,42 @@ def datapoint_callback(images, annotations, categories):


# https://docs.python.org/3/library/unittest.html#module-unittest
# class TestClientUtilMethods(unittest.TestCase):
# def test_remove_n_extensions(self):
# self.assertTrue("/foo" == remove_n_extensions("/foo.rgb.png", 2))
# self.assertTrue("/images" == remove_n_extensions("/images.foo.rgb.png", 3))
# self.assertTrue("/images.rgb" == remove_n_extensions("/images.rgb.png", 1))
# self.assertTrue(
# "/foo/images" == remove_n_extensions("/foo/images.rgb.png", 9001)
# )
#
# def test_hash(self):
# dictA = hash({"foo": 1, "bar": 2})
# dictB = hash({"bar": 2, "foo": 1})
# self.assertEqual(hash(dictA), hash(dictB))
# self.assertEqual(hash(True), hash(True))
# self.assertNotEqual(hash(True), hash(False))
# self.assertNotEqual(hash(1), hash(2))
# self.assertNotEqual(hash([1]), hash([1, 1]))
class TestClientUtilMethods(unittest.TestCase):
    def test_remove_n_extensions(self):
        self.assertTrue("/foo" == remove_n_extensions("/foo.rgb.png", 2))
        self.assertTrue("/images" == remove_n_extensions("/images.foo.rgb.png", 3))
        self.assertTrue("/images.rgb" == remove_n_extensions("/images.rgb.png", 1))
        self.assertTrue(
            "/foo/images" == remove_n_extensions("/foo/images.rgb.png", 9001)
        )

    def test_hash(self):
        # Built-in hash() rejects dicts and lists (TypeError: unhashable type),
        # so hash dicts with dict_hash and use tuples for the sequence case.
        hash_a = dict_hash({"foo": 1, "bar": 2})
        hash_b = dict_hash({"bar": 2, "foo": 1})
        self.assertEqual(hash_a, hash_b)
        self.assertEqual(hash(True), hash(True))
        self.assertNotEqual(hash(True), hash(False))
        self.assertNotEqual(hash(1), hash(2))
        self.assertNotEqual(hash((1,)), hash((1, 1)))

    def test_generate(self):
        zpy.init(
            project_uuid="feb6e594-55e0-4f87-9e75-5a128221499f",
            auth_token="a4a13763b0dc0017b1fc9af890e9efea58fd072074ab9a169e5dcf0633310f28",
        )
        dataset_config = zpy.DatasetConfig("dumpster_v5.1")
        dataset_config.set(r"run\.padding_style", "messy")

        def datapoint_callback(images, annotations, categories):
            pretty_print(images)
            pretty_print(annotations)
            pretty_print(categories)

        zpy.generate(
            dataset_config, num_datapoints=3, datapoint_callback=datapoint_callback
        )


if __name__ == "__main__":
# unittest.main()
unittest.main()
# init_kwargs = {
# "base_url": "http://localhost:8000",
# "project_uuid": "aad8e2b2-5431-4104-a205-dc3b638b0dab",
@@ -109,8 +116,6 @@ def datapoint_callback(images, annotations, categories):
# test_1(**init_kwargs)
# print("Running test_2:")
# test_2(**init_kwargs)
# print("Running test_3:")
test_3()
# test format dataset

# def datapoint_callback(images, annotations, categories):
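For reference, a minimal sketch of remove_n_extensions that satisfies the assertions in the new test class (the real implementation lives in zpy/client_util.py and may differ):

    from pathlib import Path

    def remove_n_extensions(path, n):
        # Strip up to n trailing extensions; an oversized n simply stops once
        # no suffix remains, e.g. ("/foo/images.rgb.png", 9001) -> "/foo/images".
        p = Path(path)
        for _ in range(n):
            if p.suffix == "":
                break
            p = p.with_suffix("")
        return str(p)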
13 changes: 6 additions & 7 deletions zpy/client.py
@@ -24,7 +24,7 @@
clear_last_print,
is_done,
format_dataset,
dict_hash,
dict_hash, remove_n_extensions,
)

_auth_token: str = ""
@@ -293,20 +293,19 @@ def generate(
f"{str(dataset['name']).replace(' ', '_')}-{dataset['id'][:8]}.zip"
)
# Throw it in /tmp for now I guess
output_path = join(DATASET_OUTPUT_PATH, name_slug)
output_path = Path(DATASET_OUTPUT_PATH) / name_slug
existing_files = listdir(DATASET_OUTPUT_PATH)
if name_slug not in existing_files:
print(
f"Downloading {convert_size(dataset_download_res['size_bytes'])} dataset to {output_path}"
)
download_url(dataset_download_res["redirect_link"], output_path)
unzipped_dataset_path = extract_zip(output_path)
format_dataset(unzipped_dataset_path, datapoint_callback)
format_dataset(output_path, datapoint_callback)
print("Done.")
elif datapoint_callback is not None:
format_dataset(output_path, datapoint_callback)
else:
print(
f"Download failed. Dataset {name_slug} already exists in {output_path}."
)
print(f"Dataset {name_slug} already exists in {output_path}.")

else:
print(
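The caller-visible effect of the client.py change, as a hedged usage sketch (credentials elided; names taken from the test file above): before this commit, a locally cached zip printed "Download failed" and returned, so a datapoint_callback never ran on a repeat invocation; now the cached zip is re-formatted and the callback still fires.

    import zpy.client as zpy

    zpy.init(project_uuid="<project-uuid>", auth_token="<auth-token>")
    dataset_config = zpy.DatasetConfig("dumpster_v5.1")

    def datapoint_callback(images, annotations, categories):
        # Called once per datapoint with remapped ids and image paths.
        print(len(images), len(annotations), len(categories))

    # On a repeat run the download is skipped, but format_dataset is still
    # invoked on the cached zip, so the callback runs again.
    zpy.generate(dataset_config, num_datapoints=3, datapoint_callback=datapoint_callback)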
16 changes: 11 additions & 5 deletions zpy/client_util.py
@@ -229,7 +229,9 @@ def group_metadata_by_datapoint(
for batch in listdir(dataset_path):
batch_uri = join(dataset_path, batch)
annotation_file_uri = join(batch_uri, "_annotations.zumo.json")
metadata = json.load(open(annotation_file_uri))

with open(annotation_file_uri) as annotation_file:
metadata = json.load(annotation_file)

for c in values(metadata["categories"]):
category_count_sums[c["id"]] += c["count"]
@@ -299,18 +301,22 @@ def group_metadata_by_datapoint(
return accum_metadata, accum_categories, accum_datapoints


def format_dataset(dataset_path: Union[str, Path], datapoint_callback=None) -> None:
def format_dataset(zipped_dataset_path: Union[str, Path], datapoint_callback=None) -> None:
"""
Updates metadata with new ids and accurate image paths.
If a datapoint_callback is provided, it is called once per datapoint with the updated metadata.
Otherwise, the default is to write out an updated _annotations.zumo.json, along with all images, to a new adjacent folder.
Args:
dataset_path (str): Path to unzipped dataset.
zipped_dataset_path (str): Path to the zipped dataset; extracted on demand if no unzipped copy exists.
datapoint_callback (Callable): User-defined function called once per datapoint with (images, annotations, categories).
Returns:
None: No return value.
"""
metadata, categories, datapoints = group_metadata_by_datapoint(dataset_path)
unzipped_dataset_path = Path(remove_n_extensions(zipped_dataset_path, n=1))
if not unzipped_dataset_path.exists():
unzipped_dataset_path = extract_zip(zipped_dataset_path)

metadata, categories, datapoints = group_metadata_by_datapoint(unzipped_dataset_path)

if datapoint_callback is not None:
for datapoint in datapoints:
@@ -319,7 +325,7 @@ def format_dataset(dataset_path: Union[str, Path], datapoint_callback=None) -> None:
)

else:
output_dir = join(dataset_path.parent, dataset_path.name + "_formatted")
output_dir = join(unzipped_dataset_path.parent, unzipped_dataset_path.name + "_formatted")

accum_metadata = {
"metadata": {
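The unzip-on-demand pattern that format_dataset now encapsulates can be isolated as a small helper; a sketch assuming, as in the diff, that the archive extracts to a sibling directory named after the zip minus its extension (the standard-library zipfile stands in for extract_zip here):

    from pathlib import Path
    import zipfile

    def ensure_extracted(zipped_path) -> Path:
        # Target dir = zip path minus its final extension, mirroring
        # remove_n_extensions(zipped_path, n=1).
        out_dir = Path(zipped_path).with_suffix("")
        if not out_dir.exists():
            # Extract only once; subsequent calls reuse the directory.
            with zipfile.ZipFile(zipped_path) as zf:
                zf.extractall(out_dir)
        return out_dir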
