Commit
fix .gitignore and upload missing dir
Showing 37 changed files with 10,022 additions and 1 deletion.
.gitignore
@@ -143,7 +143,6 @@ cython_debug/
 .vscode/
 abc/
 xyz/
-data/
 **/_backup_/**
 **/exp/**
 **/ckpts/**
Adaptive Frequency Filters/data/__init__.py
@@ -0,0 +1,6 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

from .data_loaders import create_train_val_loader, create_eval_loader
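
With this re-export in place, training code can pull both loader factories from the package root instead of reaching into data.data_loaders. A minimal sketch, assuming the repository root (the directory containing the data package) is on PYTHONPATH and that opts is the project's parsed options object:

# hypothetical usage sketch; `opts` is assumed to come from the project's option parser
from data import create_train_val_loader, create_eval_loader

train_loader, val_loader, train_sampler = create_train_val_loader(opts)
eval_loader = create_eval_loader(opts)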
Adaptive Frequency Filters/data/collate_fns/__init__.py
@@ -0,0 +1,94 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import os
import importlib
import argparse

COLLATE_FN_REGISTRY = {}


def register_collate_fn(name):
    def register_collate_fn_method(f):
        if name in COLLATE_FN_REGISTRY:
            raise ValueError(
                "Cannot register duplicate collate function ({})".format(name)
            )
        COLLATE_FN_REGISTRY[name] = f
        return f

    return register_collate_fn_method

def arguments_collate_fn(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(
        title="Collate function arguments", description="Collate function arguments"
    )
    group.add_argument(
        "--dataset.collate-fn-name-train",
        type=str,
        default="default_collate_fn",
        help="Name of collate function",
    )
    group.add_argument(
        "--dataset.collate-fn-name-val",
        type=str,
        default="default_collate_fn",
        help="Name of collate function",
    )
    group.add_argument(
        "--dataset.collate-fn-name-eval",
        type=str,
        default=None,
        help="Name of collate function used for evaluation. "
        "Default is None, i.e., use PyTorch's inbuilt collate function",
    )
    return parser

def build_collate_fn(opts, *args, **kwargs):
    collate_fn_name_train = getattr(
        opts, "dataset.collate_fn_name_train", "default_collate_fn"
    )
    collate_fn_name_val = getattr(
        opts, "dataset.collate_fn_name_val", "default_collate_fn"
    )
    collate_fn_train = None
    if (
        collate_fn_name_train is not None
        and collate_fn_name_train in COLLATE_FN_REGISTRY
    ):
        collate_fn_train = COLLATE_FN_REGISTRY[collate_fn_name_train]

    collate_fn_val = None
    if collate_fn_name_val is None:
        # fall back to the train collate function resolved above
        collate_fn_val = collate_fn_train
    elif collate_fn_name_val is not None and collate_fn_name_val in COLLATE_FN_REGISTRY:
        collate_fn_val = COLLATE_FN_REGISTRY[collate_fn_name_val]

    return collate_fn_train, collate_fn_val

def build_eval_collate_fn(opts, *args, **kwargs):
    collate_fn_name_eval = getattr(opts, "dataset.collate_fn_name_eval", None)
    collate_fn_eval = None
    if collate_fn_name_eval is not None and collate_fn_name_eval in COLLATE_FN_REGISTRY:
        collate_fn_eval = COLLATE_FN_REGISTRY[collate_fn_name_eval]

    return collate_fn_eval

# automatically import the collate functions defined in this directory
collate_fn_dir = os.path.dirname(__file__)

for file in os.listdir(collate_fn_dir):
    path = os.path.join(collate_fn_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        collate_fn_fname = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("data.collate_fns." + collate_fn_fname)
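
As a usage sketch of the registry above (the custom function name, the SimpleNamespace stand-in for opts, and the sample keys are all illustrative; the module must be importable as data.collate_fns, with the repository on PYTHONPATH, for the auto-import loop to run):

# hypothetical sketch: register a custom collate function, then resolve it by name
from types import SimpleNamespace

import torch

from data.collate_fns import COLLATE_FN_REGISTRY, build_collate_fn, register_collate_fn


@register_collate_fn(name="image_label_collate_fn")
def image_label_collate_fn(batch, opts):
    # stack image tensors along a new batch dimension and collect labels
    images = torch.stack([sample["image"] for sample in batch], dim=0)
    labels = torch.tensor([sample["label"] for sample in batch])
    return {"image": images, "label": labels}


# build_collate_fn reads dotted option names via getattr, so set them the same way here
opts = SimpleNamespace()
setattr(opts, "dataset.collate_fn_name_train", "image_label_collate_fn")
setattr(opts, "dataset.collate_fn_name_val", "default_collate_fn")

collate_train, collate_val = build_collate_fn(opts)
assert collate_train is COLLATE_FN_REGISTRY["image_label_collate_fn"]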
42 changes: 42 additions & 0 deletions
Adaptive Frequency Filters/data/collate_fns/collate_functions.py
@@ -0,0 +1,42 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import numpy as np
import torch
from typing import List, Dict

from utils import logger

from . import register_collate_fn


@register_collate_fn(name="default_collate_fn")
def default_collate_fn(batch: List[Dict], opts):
    """Default collate function"""
    batch_size = len(batch)

    keys = list(batch[0].keys())

    new_batch = {k: [] for k in keys}
    for b in range(batch_size):
        for k in keys:
            new_batch[k].append(batch[b][k])

    # stack the keys
    for k in keys:
        batch_elements = new_batch.pop(k)

        if isinstance(batch_elements[0], (int, float, np.integer, np.floating)):
            # list of ints or floats
            batch_elements = torch.as_tensor(batch_elements)
        else:
            # stack tensors (including 0-dimensional)
            try:
                batch_elements = torch.stack(batch_elements, dim=0).contiguous()
            except Exception as e:
                logger.error("Unable to stack the tensors. Error: {}".format(e))

        new_batch[k] = batch_elements

    return new_batch
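
For illustration, a small sketch of what default_collate_fn returns for a two-sample batch (the keys and tensor shapes are made up for the example; opts is not used by the function body, so None is passed):

# hypothetical sketch: batching two samples with default_collate_fn
import torch

batch = [
    {"image": torch.zeros(3, 32, 32), "label": 0},
    {"image": torch.ones(3, 32, 32), "label": 1},
]
out = default_collate_fn(batch, opts=None)

# tensors are stacked along a new batch dimension; Python ints become a 1-D tensor
assert out["image"].shape == (2, 3, 32, 32)
assert out["label"].tolist() == [0, 1]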
Adaptive Frequency Filters/data/data_loaders.py
@@ -0,0 +1,137 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import torch
from functools import partial

from utils import logger
from utils.ddp_utils import is_master
from utils.tensor_utils import image_size_from_opts

from .datasets import train_val_datasets, evaluation_datasets
from .sampler import build_sampler
from .collate_fns import build_collate_fn, build_eval_collate_fn
from .loader.dataloader import affnetDataLoader

def create_eval_loader(opts):
    eval_dataset = evaluation_datasets(opts)
    n_eval_samples = len(eval_dataset)
    is_master_node = is_master(opts)

    # overwrite the validation argument
    setattr(
        opts, "dataset.val_batch_size0", getattr(opts, "dataset.eval_batch_size0", 1)
    )

    # we don't need variable batch sampler for evaluation
    sampler_name = getattr(opts, "sampler.name", "batch_sampler")
    crop_size_h, crop_size_w = image_size_from_opts(opts)
    if sampler_name.find("video") > -1 and sampler_name != "video_batch_sampler":
        clips_per_video = getattr(opts, "sampler.vbs.clips_per_video", 1)
        frames_per_clip = getattr(opts, "sampler.vbs.num_frames_per_clip", 8)
        setattr(opts, "sampler.name", "video_batch_sampler")
        setattr(opts, "sampler.bs.crop_size_width", crop_size_w)
        setattr(opts, "sampler.bs.crop_size_height", crop_size_h)
        setattr(opts, "sampler.bs.clips_per_video", clips_per_video)
        setattr(opts, "sampler.bs.num_frames_per_clip", frames_per_clip)
    elif sampler_name.find("var") > -1:
        setattr(opts, "sampler.name", "batch_sampler")
        setattr(opts, "sampler.bs.crop_size_width", crop_size_w)
        setattr(opts, "sampler.bs.crop_size_height", crop_size_h)

    eval_sampler = build_sampler(
        opts=opts, n_data_samples=n_eval_samples, is_training=False
    )

    collate_fn_eval = build_eval_collate_fn(opts=opts)

    data_workers = getattr(opts, "dataset.workers", 1)
    persistent_workers = False
    pin_memory = False

    eval_loader = affnetDataLoader(
        dataset=eval_dataset,
        batch_size=1,
        batch_sampler=eval_sampler,
        num_workers=data_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers,
        collate_fn=partial(collate_fn_eval, opts=opts)
        if collate_fn_eval is not None
        else None,
    )

    if is_master_node:
        logger.log("Evaluation sampler details: ")
        print("{}".format(eval_sampler))

    return eval_loader

def create_train_val_loader(opts):
    train_dataset, valid_dataset = train_val_datasets(opts)

    n_train_samples = len(train_dataset)
    is_master_node = is_master(opts)

    train_sampler = build_sampler(
        opts=opts, n_data_samples=n_train_samples, is_training=True
    )
    if valid_dataset is not None:
        n_valid_samples = len(valid_dataset)
        valid_sampler = build_sampler(
            opts=opts, n_data_samples=n_valid_samples, is_training=False
        )
    else:
        valid_sampler = None

    data_workers = getattr(opts, "dataset.workers", 1)
    persistent_workers = getattr(opts, "dataset.persistent_workers", False) and (
        data_workers > 0
    )
    pin_memory = getattr(opts, "dataset.pin_memory", False)
    prefetch_factor = getattr(opts, "dataset.prefetch_factor", 2)

    collate_fn_train, collate_fn_val = build_collate_fn(opts=opts)

    train_loader = affnetDataLoader(
        dataset=train_dataset,
        batch_size=1,  # Handled inside data sampler
        num_workers=data_workers,
        pin_memory=pin_memory,
        batch_sampler=train_sampler,
        persistent_workers=persistent_workers,
        collate_fn=partial(collate_fn_train, opts=opts)
        if collate_fn_train is not None
        else None,
        prefetch_factor=prefetch_factor,
    )

    if valid_dataset is not None:
        val_loader = affnetDataLoader(
            dataset=valid_dataset,
            batch_size=1,
            batch_sampler=valid_sampler,
            num_workers=data_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            collate_fn=partial(collate_fn_val, opts=opts)
            if collate_fn_val is not None
            else None,
        )
    else:
        val_loader = None

    if is_master_node:
        logger.log("Training sampler details: ")
        print("{}".format(train_sampler))

        if valid_dataset is not None:
            logger.log("Validation sampler details: ")
            print("{}".format(valid_sampler))
        logger.log("Number of data workers: {}".format(data_workers))

    return train_loader, val_loader, train_sampler
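
As an end-to-end sketch of how these factories might be consumed (the epoch loop, the set_epoch hook, and the batch keys are assumptions rather than something shown in this diff; opts again denotes the project's parsed options object):

# hypothetical training-entry sketch built on create_train_val_loader
from data import create_train_val_loader

train_loader, val_loader, train_sampler = create_train_val_loader(opts)

for epoch in range(10):
    # many samplers expose an epoch hook for reshuffling; guard the call since this
    # is an assumption about whatever build_sampler returned
    if hasattr(train_sampler, "set_epoch"):
        train_sampler.set_epoch(epoch)
    for batch in train_loader:
        images = batch["image"]  # key names depend on the dataset implementation
        # ... forward pass, loss, backward, optimizer step ...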