-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tile catalog to region catalog conversion script (#935)
* Add tile catalog to region catalog conversion script * Fix import
- Loading branch information
1 parent
2590a63
commit 88e3943
Showing
1 changed file
with
51 additions
and
0 deletions.
There are no files selected for viewing
51 changes: 51 additions & 0 deletions
51
case_studies/adaptive_tiling/convert_tilecat_to_regioncat.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""Convert tile-catalog datasets into region-catalog datasets.

For every ``dataset_*.pt`` file in the input directory, each item's
``"tile_catalog"`` entry is converted with ``tile_cat_to_region_cat`` and the
converted items are saved under the same file name in the output directory.
Items whose conversion raises an exception or emits a warning are skipped.

Usage:
    python convert_tilecat_to_regioncat.py INPUT_PATH OUTPUT_PATH [--overlap 0.4]
"""

import argparse
import warnings
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from bliss.catalog import TileCatalog
from bliss.generate import itemize_data
from case_studies.adaptive_tiling.region_catalog import tile_cat_to_region_cat

parser = argparse.ArgumentParser()
parser.add_argument("input_path", type=str, help="Path to directory containing tile-based data")
parser.add_argument("output_path", type=str, help="Path to directory to store converted data")
h = "Overlap to use for conversion. Default is 0.4."
parser.add_argument("--overlap", required=False, type=float, default=0.4, help=h)
args = parser.parse_args()

OVERLAP_SLEN = args.overlap
INPUT_PATH = Path(args.input_path)
OUTPUT_PATH = Path(args.output_path)

# exist_ok avoids the check-then-create race and makes re-runs a no-op.
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

# Iterate over each data file in the input directory (sorted for a
# deterministic processing order).
files = sorted(INPUT_PATH.glob("dataset_*.pt"))
for filename in tqdm(files):
    dataloader = DataLoader(torch.load(filename), batch_size=1)  # get one at a time
    region_data = []
    skip_count = 0
    for batch in dataloader:
        # Try conversion. If it fails, skip this image.
        try:
            # Promote warnings to errors only around the conversion, so that a
            # warning there skips the image while unrelated warnings (e.g. from
            # loading or saving) do not abort the whole run.
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                # NOTE(review): 4 is presumably the tile side length - confirm.
                tile_cat = TileCatalog(4, batch["tile_catalog"])
                region_cat = tile_cat_to_region_cat(tile_cat, OVERLAP_SLEN)
        except Exception:  # pylint: disable=broad-except
            # `Exception` (not a bare except) keeps Ctrl-C / SystemExit working.
            skip_count += 1
            continue

        batch["tile_catalog"] = region_cat.to_dict()
        region_data.append(batch)

    # With batch_size=1, len(dataloader) is the number of items in the file.
    print(f"Skipped {skip_count} out of {len(dataloader)} total images")  # noqa: WPS421

    # Save converted data to a new file with the same name as the input file.
    output_filename = OUTPUT_PATH / filename.name
    with open(output_filename, "wb") as f:
        torch.save(itemize_data(region_data), f)