Skip to content

Commit

Permalink
fix .gitignore and upload missing dir
Browse files Browse the repository at this point in the history
  • Loading branch information
hzphzp committed Aug 29, 2023
1 parent d7b21c4 commit fd0402b
Show file tree
Hide file tree
Showing 37 changed files with 10,022 additions and 1 deletion.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ cython_debug/
.vscode/
abc/
xyz/
data/
**/_backup_/**
**/exp/**
**/ckpts/**
Expand Down
6 changes: 6 additions & 0 deletions Adaptive Frequency Filters/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

from .data_loaders import create_train_val_loader, create_eval_loader
94 changes: 94 additions & 0 deletions Adaptive Frequency Filters/data/collate_fns/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import os
import importlib
import argparse

COLLATE_FN_REGISTRY = {}


def register_collate_fn(name):
def register_collate_fn_method(f):
if name in COLLATE_FN_REGISTRY:
raise ValueError(
"Cannot register duplicate collate function ({})".format(name)
)
COLLATE_FN_REGISTRY[name] = f
return f

return register_collate_fn_method


def arguments_collate_fn(parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="Collate function arguments", description="Collate function arguments"
)
group.add_argument(
"--dataset.collate-fn-name-train",
type=str,
default="default_collate_fn",
help="Name of collate function",
)
group.add_argument(
"--dataset.collate-fn-name-val",
type=str,
default="default_collate_fn",
help="Name of collate function",
)
group.add_argument(
"--dataset.collate-fn-name-eval",
type=str,
default=None,
help="Name of collate function used for evaluation. "
"Default is None, i.e., use PyTorch's inbuilt collate function",
)
return parser


def build_collate_fn(opts, *args, **kwargs):
collate_fn_name_train = getattr(
opts, "dataset.collate_fn_name_train", "default_collate_fn"
)
collate_fn_name_val = getattr(
opts, "dataset.collate_fn_name_val", "default_collate_fn"
)
collate_fn_train = None
if (
collate_fn_name_train is not None
and collate_fn_name_train in COLLATE_FN_REGISTRY
):
collate_fn_train = COLLATE_FN_REGISTRY[collate_fn_name_train]

collate_fn_val = None
if collate_fn_name_val is None:
collate_fn_val = collate_fn_name_train
elif collate_fn_name_val is not None and collate_fn_name_val in COLLATE_FN_REGISTRY:
collate_fn_val = COLLATE_FN_REGISTRY[collate_fn_name_val]

return collate_fn_train, collate_fn_val


def build_eval_collate_fn(opts, *args, **kwargs):
collate_fn_name_eval = getattr(opts, "dataset.collate_fn_name_eval", None)
collate_fn_eval = None
if collate_fn_name_eval is not None and collate_fn_name_eval in COLLATE_FN_REGISTRY:
collate_fn_eval = COLLATE_FN_REGISTRY[collate_fn_name_eval]

return collate_fn_eval


# automatically import the augmentations
collate_fn_dir = os.path.dirname(__file__)

for file in os.listdir(collate_fn_dir):
path = os.path.join(collate_fn_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
collate_fn_fname = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module("data.collate_fns." + collate_fn_fname)
42 changes: 42 additions & 0 deletions Adaptive Frequency Filters/data/collate_fns/collate_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import numpy as np
import torch
from typing import List, Dict

from utils import logger

from . import register_collate_fn


@register_collate_fn(name="default_collate_fn")
def default_collate_fn(batch: List[Dict], opts):
"""Default collate function"""
batch_size = len(batch)

keys = list(batch[0].keys())

new_batch = {k: [] for k in keys}
for b in range(batch_size):
for k in keys:
new_batch[k].append(batch[b][k])

# stack the keys
for k in keys:
batch_elements = new_batch.pop(k)

if isinstance(batch_elements[0], (int, float, np.integer, np.floating)):
# list of ints or floats
batch_elements = torch.as_tensor(batch_elements)
else:
# stack tensors (including 0-dimensional)
try:
batch_elements = torch.stack(batch_elements, dim=0).contiguous()
except Exception as e:
logger.error("Unable to stack the tensors. Error: {}".format(e))

new_batch[k] = batch_elements

return new_batch
137 changes: 137 additions & 0 deletions Adaptive Frequency Filters/data/data_loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import torch
from functools import partial

from utils import logger
from utils.ddp_utils import is_master
from utils.tensor_utils import image_size_from_opts

from .datasets import train_val_datasets, evaluation_datasets
from .sampler import build_sampler
from .collate_fns import build_collate_fn, build_eval_collate_fn
from .loader.dataloader import affnetDataLoader


def create_eval_loader(opts):
eval_dataset = evaluation_datasets(opts)
n_eval_samples = len(eval_dataset)
is_master_node = is_master(opts)

# overwrite the validation argument
setattr(
opts, "dataset.val_batch_size0", getattr(opts, "dataset.eval_batch_size0", 1)
)

# we don't need variable batch sampler for evaluation
sampler_name = getattr(opts, "sampler.name", "batch_sampler")
crop_size_h, crop_size_w = image_size_from_opts(opts)
if sampler_name.find("video") > -1 and sampler_name != "video_batch_sampler":
clips_per_video = getattr(opts, "sampler.vbs.clips_per_video", 1)
frames_per_clip = getattr(opts, "sampler.vbs.num_frames_per_clip", 8)
setattr(opts, "sampler.name", "video_batch_sampler")
setattr(opts, "sampler.bs.crop_size_width", crop_size_w)
setattr(opts, "sampler.bs.crop_size_height", crop_size_h)
setattr(opts, "sampler.bs.clips_per_video", clips_per_video)
setattr(opts, "sampler.bs.num_frames_per_clip", frames_per_clip)
elif sampler_name.find("var") > -1:
setattr(opts, "sampler.name", "batch_sampler")
setattr(opts, "sampler.bs.crop_size_width", crop_size_w)
setattr(opts, "sampler.bs.crop_size_height", crop_size_h)

eval_sampler = build_sampler(
opts=opts, n_data_samples=n_eval_samples, is_training=False
)

collate_fn_eval = build_eval_collate_fn(opts=opts)

data_workers = getattr(opts, "dataset.workers", 1)
persistent_workers = False
pin_memory = False

eval_loader = affnetDataLoader(
dataset=eval_dataset,
batch_size=1,
batch_sampler=eval_sampler,
num_workers=data_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers,
collate_fn=partial(collate_fn_eval, opts=opts)
if collate_fn_eval is not None
else None,
)

if is_master_node:
logger.log("Evaluation sampler details: ")
print("{}".format(eval_sampler))

return eval_loader


def create_train_val_loader(opts):
train_dataset, valid_dataset = train_val_datasets(opts)

n_train_samples = len(train_dataset)
is_master_node = is_master(opts)

train_sampler = build_sampler(
opts=opts, n_data_samples=n_train_samples, is_training=True
)
if valid_dataset is not None:
n_valid_samples = len(valid_dataset)
valid_sampler = build_sampler(
opts=opts, n_data_samples=n_valid_samples, is_training=False
)
else:
valid_sampler = None

data_workers = getattr(opts, "dataset.workers", 1)
persistent_workers = getattr(opts, "dataset.persistent_workers", False) and (
data_workers > 0
)
pin_memory = getattr(opts, "dataset.pin_memory", False)
prefetch_factor = getattr(opts, "dataset.prefetch_factor", 2)

collate_fn_train, collate_fn_val = build_collate_fn(opts=opts)

train_loader = affnetDataLoader(
dataset=train_dataset,
batch_size=1, # Handled inside data sampler
num_workers=data_workers,
pin_memory=pin_memory,
batch_sampler=train_sampler,
persistent_workers=persistent_workers,
collate_fn=partial(collate_fn_train, opts=opts)
if collate_fn_train is not None
else None,
prefetch_factor=prefetch_factor,
)

if valid_dataset is not None:
val_loader = affnetDataLoader(
dataset=valid_dataset,
batch_size=1,
batch_sampler=valid_sampler,
num_workers=data_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers,
collate_fn=partial(collate_fn_val, opts=opts)
if collate_fn_val is not None
else None,
)
else:
val_loader = None

if is_master_node:
logger.log("Training sampler details: ")
print("{}".format(train_sampler))

if valid_dataset is not None:
logger.log("Validation sampler details: ")
print("{}".format(valid_sampler))
logger.log("Number of data workers: {}".format(data_workers))

return train_loader, val_loader, train_sampler
Loading

0 comments on commit fd0402b

Please sign in to comment.