add a function download_buildin_dataset for scripts
leondgarse committed Dec 23, 2023
1 parent d2e7d53 commit 41b4006
Showing 5 changed files with 27 additions and 16 deletions.
clip_train_script.py (5 changes: 1 addition & 4 deletions)

@@ -103,10 +103,7 @@ def parse_arguments():
     args = parse_arguments()
 
     if args.data_path in BUILDIN_DATASETS and not os.path.exists(args.data_path):
-        url, dataset_file = BUILDIN_DATASETS[args.data_path]["url"], BUILDIN_DATASETS[args.data_path]["dataset_file"]
-        file_path = kecam.backend.get_file(origin=url, cache_subdir="datasets", extract=True)  # returned tar file path
-        args.data_path = os.path.join(os.path.dirname(file_path), args.data_path, dataset_file)
-        print(">>>> Buildin dataset, path:", args.data_path)
+        args.data_path = kecam.backend.download_buildin_dataset(args.data_path, BUILDIN_DATASETS, cache_subdir="datasets")
 
     caption_tokenizer = getattr(kecam.clip, args.tokenizer)() if hasattr(kecam.clip, args.tokenizer) else kecam.clip.TikToken(args.tokenizer)
     train_dataset, test_dataset = data.init_dataset(
coco_train_script.py (9 changes: 3 additions & 6 deletions)

@@ -183,13 +183,10 @@ def run_training_by_args(args):
     batch_size = args.batch_size * strategy.num_replicas_in_sync
     input_shape = (args.input_shape, args.input_shape, 3)
 
-    if args.data_name in BUILDIN_DATASETS and not os.path.exists(args.data_name):
-        from keras_cv_attention_models.backend import get_file
+    if args.data_name in BUILDIN_DATASETS:
+        from keras_cv_attention_models.backend import download_buildin_dataset
 
-        url, dataset_file = BUILDIN_DATASETS[args.data_name]["url"], BUILDIN_DATASETS[args.data_name]["dataset_file"]
-        file_path = get_file(origin=url, cache_subdir="datasets", extract=True)  # returned tar file path
-        args.data_name = os.path.join(os.path.dirname(file_path), args.data_name, dataset_file)
-        print(">>>> Buildin dataset, path:", args.data_name)
+        args.data_name = download_buildin_dataset(args.data_name, BUILDIN_DATASETS, cache_subdir="datasets")
 
     # Init model first, for getting actual pyramid_levels
     total_images, num_classes, steps_per_epoch = data.init_dataset(args.data_name, batch_size=batch_size, info_only=True)
ddpm_train_script.py (7 changes: 1 addition & 6 deletions)

@@ -98,12 +98,7 @@ def parse_arguments():
     print(">>>> args:", args)
 
     if args.data_path in BUILDIN_DATASETS and not os.path.exists(args.data_path):
-        url, dataset_file = BUILDIN_DATASETS[args.data_path]["url"], BUILDIN_DATASETS[args.data_path]["dataset_file"]
-        file_path = os.path.join(os.path.expanduser("~"), ".keras", "datasets", args.data_path)
-        if not os.path.exists(file_path):
-            file_path = kecam.backend.get_file(origin=url, cache_subdir="datasets", extract=True)  # returned tar file path
-        args.data_path = os.path.join(os.path.dirname(file_path), args.data_path, dataset_file)
-        print(">>>> Buildin dataset, path:", args.data_path)
+        args.data_path = kecam.backend.download_buildin_dataset(args.data_path, BUILDIN_DATASETS, cache_subdir="datasets")
 
     if args.data_path.endswith(".json"):
         all_images, all_labels, num_classes = kecam.stable_diffusion.data.init_from_json(args.data_path)
keras_cv_attention_models/backend.py (10 changes: 10 additions & 0 deletions)

@@ -144,3 +144,13 @@ def numpy_image_resize(inputs, target_shape, method="bilinear", antialias=False,
     inputs = inputs.transpose([0, 3, 1, 2]) if image_data_format() == "channels_last" else inputs
     inputs = inputs if ndims == 4 else (inputs[0] if ndims == 3 else inputs[0, 0])
     return inputs
+
+
+def download_buildin_dataset(data_name, buildin_datasets, cache_subdir="datasets"):
+    url, dataset_file = buildin_datasets[data_name]["url"], buildin_datasets[data_name]["dataset_file"]
+    file_path = os.path.join(os.path.expanduser("~"), ".keras", cache_subdir, data_name)
+    if not os.path.exists(file_path):
+        file_path = get_file(origin=url, cache_subdir=cache_subdir, extract=True)  # returned tar file path
+    data_name = os.path.join(os.path.dirname(file_path), data_name, dataset_file)
+    print(">>>> Buildin dataset, path:", data_name)
+    return data_name
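
The new helper centralizes the download-once-then-reuse logic: it looks for a cached copy under ~/.keras/<cache_subdir>/<data_name> first and only calls get_file (download plus tar extraction) on a cache miss, then returns the path to the dataset file. Notably, ddpm_train_script.py already carried this cache check; the helper generalizes it to all scripts. A minimal standalone sketch, reusing the coco_dog_cat entry that train_script.py registers below (calling the helper directly like this, outside the train scripts, is illustrative):

    from keras_cv_attention_models.backend import download_buildin_dataset

    BUILDIN_DATASETS = {
        "coco_dog_cat": {
            "url": "https://github.com/leondgarse/keras_cv_attention_models/releases/download/assets/coco_dog_cat.tar.gz",
            "dataset_file": "recognition.json",
        },
    }

    # Downloads and extracts the tar only if ~/.keras/datasets/coco_dog_cat is missing,
    # then resolves to ~/.keras/datasets/coco_dog_cat/recognition.json
    data_path = download_buildin_dataset("coco_dog_cat", BUILDIN_DATASETS, cache_subdir="datasets")
    print(data_path)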
train_script.py (12 changes: 12 additions & 0 deletions)

@@ -3,6 +3,13 @@
 import json
 from keras_cv_attention_models.imagenet import data, train_func, losses
 
+BUILDIN_DATASETS = {
+    "coco_dog_cat": {
+        "url": "https://github.com/leondgarse/keras_cv_attention_models/releases/download/assets/coco_dog_cat.tar.gz",
+        "dataset_file": "recognition.json",
+    },
+}
+
 
 def parse_arguments(argv):
     import argparse
@@ -154,6 +161,11 @@ def run_training_by_args(args):
     use_teacher_model = args.teacher_model is not None
     teacher_model_input_shape = input_shape if args.teacher_model_input_shape == -1 else (args.teacher_model_input_shape, args.teacher_model_input_shape)
 
+    if args.data_name in BUILDIN_DATASETS:
+        from keras_cv_attention_models.backend import download_buildin_dataset
+
+        args.data_name = download_buildin_dataset(args.data_name, BUILDIN_DATASETS, cache_subdir="datasets")
+
     # Init model first, for in case of use_token_label, getting token_label_target_patches
     total_images, num_classes, steps_per_epoch, num_channels = data.init_dataset(args.data_name, batch_size=batch_size, info_only=True)
     input_shape = (*input_shape, num_channels)  # Just in case channel is not 3, like mnist being 1...
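
With these changes, passing a builtin dataset name to any of the scripts fetches it automatically on first use. A hedged command-line sketch (assuming train_script.py exposes the dataset through a --data_name flag, as the args.data_name references above suggest):

    # First run downloads coco_dog_cat.tar.gz into ~/.keras/datasets/ and resolves
    # the name to .../coco_dog_cat/recognition.json; later runs hit the cache.
    python train_script.py --data_name coco_dog_cat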
