Skip to content

Commit

Permalink
update for torch COCO data
Browse files Browse the repository at this point in the history
  • Loading branch information
leondgarse committed Feb 23, 2024
1 parent 093140f commit 8690a13
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions keras_cv_attention_models/coco/torch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def collate_wrapper(batch):
return torch.stack(images), torch.concat([batch_ids[:, None], bboxes, labels[:, None]], dim=-1)


def init_dataset(data_path, batch_size=64, image_size=640, num_workers=8):
def init_dataset(data_path, batch_size=64, image_size=640, num_workers=8, with_info=False):
"""
>>> os.environ["KECAM_BACKEND"] = "torch"
>>> from keras_cv_attention_models.coco import torch_data
Expand All @@ -276,7 +276,7 @@ def init_dataset(data_path, batch_size=64, image_size=640, num_workers=8):
>>> ax = show_image_with_bboxes(image, label[:, 1:-1], label[:, -1], indices_2_labels={0: 'cat', 1: 'dog'})
>>> ax.get_figure().savefig('aa.jpg')
"""
train, test = load_from_custom_json(data_path)
train, test, total_images, num_classes = load_from_custom_json(data_path, with_info=with_info)

train_dataset = DetectionDataset(train, is_train=True, image_size=image_size)
train_dataloader = DataLoader(
Expand All @@ -288,4 +288,4 @@ def init_dataset(data_path, batch_size=64, image_size=640, num_workers=8):
test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_wrapper, pin_memory=True, sampler=None, drop_last=False
)

return train_dataloader, test_dataloader
return (train_dataloader, test_dataloader, total_images, num_classes) if with_info else (train_dataloader, test_dataloader)

0 comments on commit 8690a13

Please sign in to comment.