Skip to content

Commit 5225942

Browse files
Bartosz SmoczynskiJakub Pieszczek
authored andcommitted
Add csv reader
1 parent fb2f2d7 commit 5225942

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed

timm/data/readers/reader_factory.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from .reader_image_folder import ReaderImageFolder
55
from .reader_image_in_tar import ReaderImageInTar
6+
from .reader_paths_csv import ReaderPathsCsv
67

78

89
def create_reader(
@@ -34,6 +35,13 @@ def create_reader(
3435
from .reader_wds import ReaderWds
3536
kwargs.pop('download', False)
3637
reader = ReaderWds(root=root, name=name, split=split, **kwargs)
38+
elif "samples_csv_path" in kwargs:
39+
assert "class_map" in kwargs
40+
reader = ReaderPathsCsv(
41+
images_dir=root,
42+
samples_csv_path=kwargs["samples_csv_path"],
43+
class_map=kwargs["class_map"],
44+
)
3745
else:
3846
assert os.path.exists(root)
3947
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
A dataset reader that extracts images from a single folder
3+
based on a csv with labels and filenames relative to that folder.
4+
"""
5+
import os
6+
import pandas as pd
7+
8+
from .reader import Reader
9+
10+
11+
class ReaderPathsCsv(Reader):
12+
def __init__(
13+
self,
14+
images_dir,
15+
samples_csv_path,
16+
class_map: dict[str, int],
17+
):
18+
super().__init__()
19+
assert isinstance(class_map, dict)
20+
21+
self.images_dir = images_dir
22+
samples_df = pd.read_csv(samples_csv_path).astype(str)
23+
24+
if not samples_df["label"].isin(class_map).all():
25+
unrecognized_ids = ~samples_df["label"].isin(class_map)
26+
unrecognized_labels = set(samples_df.loc[unrecognized_ids, "label"])
27+
raise ValueError(f"Unrecognized labels found in samples_df: {unrecognized_labels}")
28+
29+
samples_df["label"] = samples_df["label"].map(class_map)
30+
31+
self.samples_df = samples_df
32+
33+
def __getitem__(self, index):
34+
filename, target = self.samples_df.iloc[index]
35+
path = os.path.join(self.images_dir, filename)
36+
return open(path, 'rb'), target
37+
38+
def __len__(self):
39+
return len(self.samples_df.index)
40+
41+
def _filename(self, index, basename=False, absolute=False):
42+
filename = self.samples_df.iloc[index, "filename"]
43+
if basename:
44+
filename = os.path.basename(filename)
45+
elif not absolute:
46+
filename = os.path.relpath(filename, self.images_dir)
47+
return filename

timm/train.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@
8888
help='path to dataset (root dir)')
8989
group.add_argument('--dataset', metavar='NAME', default='',
9090
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
91+
parser.add_argument('--train-samples-csv-path', metavar='PATH',
92+
help='path to csv with train filenames and labels')
93+
parser.add_argument('--val-samples-csv-path', metavar='PATH',
94+
help='path to csv with train filenames and labels')
9195
group.add_argument('--train-split', metavar='NAME', default='train',
9296
help='dataset train split (default: train)')
9397
group.add_argument('--val-split', metavar='NAME', default='validation',
@@ -685,6 +689,7 @@ def train(config: dict[str, t.Any]):
685689
target_key=args.target_key,
686690
num_samples=args.train_num_samples,
687691
trust_remote_code=args.dataset_trust_remote_code,
692+
samples_csv_path=args.train_samples_csv_path,
688693
)
689694

690695
if args.val_split:
@@ -701,6 +706,7 @@ def train(config: dict[str, t.Any]):
701706
target_key=args.target_key,
702707
num_samples=args.val_num_samples,
703708
trust_remote_code=args.dataset_trust_remote_code,
709+
samples_csv_path=args.val_samples_csv_path,
704710
)
705711

706712
# setup mixup / cutmix

0 commit comments

Comments
 (0)