Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
79ad788
Add YOLO dataset format support
mario-dg Mar 28, 2025
9d6c8d4
Rename is_correct_yolo_format method
mario-dg Mar 28, 2025
c690c93
Add YOLO format data loader, ensure COCO API compliance
mario-dg Mar 29, 2025
b9ae481
Clean up dataloader
mario-dg Mar 29, 2025
8234f66
Undo last changes
mario-dg Mar 29, 2025
fda8d9b
image_ids should start at 1 and class ids were mismatched
mario-dg Mar 31, 2025
67c6191
Cleanup docs of latest changes
mario-dg Mar 31, 2025
d388ef4
Ensure non-string idx
mario-dg Mar 31, 2025
4a98273
Try to fix COCO like API
mario-dg Mar 31, 2025
19c4a18
Use proper dataset structure in COCOLikeAPI
mario-dg Mar 31, 2025
ff4bc6c
Fix image and annotation ID in COCOLikeAPI
mario-dg Mar 31, 2025
89029a0
Consistency between images and annotations in COCOLikeAPI
mario-dg Mar 31, 2025
e92ced3
Remove default list when retrieving class names from data.yaml
mario-dg Mar 31, 2025
834c54e
Code Review improvements
mario-dg Mar 31, 2025
a5884c5
Remove more unnecessary inline comments
mario-dg Mar 31, 2025
54dc467
More code review improvements
mario-dg Mar 31, 2025
40e5273
Fix imports
mario-dg Mar 31, 2025
19adea7
Notify user, if image or label files are skipped
mario-dg Mar 31, 2025
e4b361e
Correct usage of supervision file util methods
mario-dg Mar 31, 2025
d09d468
Fix params of parse_yolo_annotations
mario-dg Mar 31, 2025
b3844f2
Forgot to return boxes and labels
mario-dg Mar 31, 2025
5c9b972
Use constant in build_yolo for consistency
mario-dg Mar 31, 2025
f8e15da
Merge remote-tracking branch 'origin/main' into support-yolo-format
mario-dg Apr 1, 2025
b1aeaad
Merge branch 'develop' into support-yolo-format
mario-dg Apr 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,51 @@ dataset/

[Roboflow](https://roboflow.com/annotate) allows you to create object detection datasets from scratch or convert existing datasets from formats like YOLO, and then export them in COCO JSON format for training. You can also explore [Roboflow Universe](https://universe.roboflow.com/) to find pre-labeled datasets for a range of use cases.

### YOLO Format Support

RF-DETR now also supports training directly on YOLO format datasets. The dataset should follow the standard YOLO format structure:

```
dataset/
├── data.yaml # Contains class names, number of classes, etc.
├── train/
│ ├── images/
│ │ ├── image1.jpg
│ │ ├── image2.jpg
│ │ └── ... (other image files)
│ └── labels/
│ ├── image1.txt
│ ├── image2.txt
│ └── ... (other label files)
├── valid/ # Note: 'valid' is used instead of 'val'
│ ├── images/
│ │ └── ... (image files)
│ └── labels/
│ └── ... (label files)
└── test/ # Optional
├── images/
│ └── ... (image files)
└── labels/
└── ... (label files)
```

Each label file contains annotations in YOLO format: `class_id x_center y_center width height` with normalized coordinates (0-1).

```python
from rfdetr import RFDETRBase

model = RFDETRBase()

model.train(
dataset_dir=<YOLO_DATASET_PATH>,
epochs=10,
batch_size=4,
grad_accum_steps=4,
lr=1e-4,
output_dir=<OUTPUT_PATH>
)
```

### Fine-tuning

You can fine-tune RF-DETR from pre-trained COCO checkpoints. By default, the RF-DETR-B checkpoint will be used. To get started quickly, please refer to our fine-tuning Google Colab [notebook](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-finetune-rf-detr-on-detection-dataset.ipynb).
Expand Down
2 changes: 1 addition & 1 deletion rfdetr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class TrainConfig(BaseModel):
ia_bce_loss: bool = True
cls_loss_coef: float = 1.0
num_select: int = 300
dataset_file: Literal["coco", "o365", "roboflow"] = "roboflow"
dataset_file: Literal["coco", "o365", "roboflow", "yolo"] = "roboflow"
square_resize_div_64: bool = True
dataset_dir: str
output_dir: str = "output"
Expand Down
5 changes: 5 additions & 0 deletions rfdetr/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .coco import build as build_coco
from .o365 import build_o365
from .coco import build_roboflow
from .yolo import build_yolo, YOLODataset


def get_coco_api_from_dataset(dataset):
Expand All @@ -24,6 +25,8 @@ def get_coco_api_from_dataset(dataset):
dataset = dataset.dataset
if isinstance(dataset, torchvision.datasets.CocoDetection):
return dataset.coco
if isinstance(dataset, YOLODataset):
return dataset.coco


def build_dataset(image_set, args, resolution):
Expand All @@ -33,4 +36,6 @@ def build_dataset(image_set, args, resolution):
return build_o365(image_set, args, resolution)
if args.dataset_file == 'roboflow':
return build_roboflow(image_set, args, resolution)
if args.dataset_file == 'yolo':
return build_yolo(image_set, args, resolution)
raise ValueError(f'dataset {args.dataset_file} not supported')
Loading