Skip to content

Commit

Permalink
Add tile catalog to region catalog conversion script (#935)
Browse files Browse the repository at this point in the history
* Add tile catalog to region catalog conversion script

* Fix import
  • Loading branch information
aakashdp6548 authored Sep 8, 2023
1 parent 2590a63 commit 88e3943
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions case_studies/adaptive_tiling/convert_tilecat_to_regioncat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import argparse
import warnings
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from bliss.catalog import TileCatalog
from bliss.generate import itemize_data
from case_studies.adaptive_tiling.region_catalog import tile_cat_to_region_cat

parser = argparse.ArgumentParser()
parser.add_argument("input_path", type=str, help="Path to directory containing tile-based data")
parser.add_argument("output_path", type=str, help="Path to directory to store converted data")
h = "Overlap to use for conversion. Default is 0.4."
parser.add_argument("--overlap", required=False, type=float, default=0.4, help=h)
args = parser.parse_args()

OVERLAP_SLEN = args.overlap
INPUT_PATH = Path(args.input_path)
OUTPUT_PATH = Path(args.output_path)

if not OUTPUT_PATH.exists():
Path.mkdir(OUTPUT_PATH, parents=True)

warnings.filterwarnings("error") # We want the warning to throw an error to know to skip an image

# Iterate over each data file in input directory
files = list(INPUT_PATH.glob("dataset_*.pt"))
for filename in tqdm(files):
dataloader = DataLoader(torch.load(filename), batch_size=1) # get one at a time
region_data = []
skip_count = 0
for batch in dataloader:
# Try conversion. If it fails, skip this image
try:
region_cat = tile_cat_to_region_cat(TileCatalog(4, batch["tile_catalog"]), OVERLAP_SLEN)
except: # pylint:disable=bare-except # noqa
skip_count += 1
continue

batch["tile_catalog"] = region_cat.to_dict()
region_data.append(batch)

print("Skipped {skip_count} out of {len(dataloader)} total images") # noqa: WPS421

# Save converted data to new file
output_filename = OUTPUT_PATH / filename.parts[-1]
with open(output_filename, "wb") as f:
torch.save(itemize_data(region_data), f)

0 comments on commit 88e3943

Please sign in to comment.