AIMv2 Image Classification Annotator #76

Merged
merged 5 commits into from
Jan 20, 2025
3 changes: 2 additions & 1 deletion README.md
@@ -181,7 +181,7 @@ datadreamer --config <path-to-config>
- `--num_objects_range`: Range of objects in a prompt. Default is 1 to 3.
- `--prompt_generator`: Choose between `simple`, `lm` (Mistral-7B), `tiny` (tiny LM), and `qwen2` (Qwen2.5 LM). Default is `qwen2`.
- `--image_generator`: Choose image generator, e.g., `sdxl`, `sdxl-turbo` or `sdxl-lightning`. Default is `sdxl-turbo`.
- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `clip` for image classification or `owlv2-slimsam` for instance segmentation. Default is `owlv2`.
- `--image_annotator`: Specify the image annotator: `owlv2` for object detection, `aimv2` or `clip` for image classification, or `owlv2-slimsam` for instance segmentation. Default is `owlv2`.
- `--conf_threshold`: Confidence threshold for annotation. Default is `0.15`.
- `--annotation_iou_threshold`: Intersection over Union (IoU) threshold for annotation. Default is `0.2`.
- `--prompt_prefix`: Prefix to add to every image generation prompt. Default is `""`.
@@ -218,6 +218,7 @@ datadreamer --config <path-to-config>
| | [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) | Fast and accurate (1024x1024 images) |
| Image Annotation | [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Vocabulary object detector |
| | [CLIP](https://huggingface.co/openai/clip-vit-base-patch32) | Zero-shot-image-classification |
| | [AIMv2](https://huggingface.co/apple/aimv2-large-patch14-224-lit) | Zero-shot-image-classification |
| | [SlimSAM](https://huggingface.co/Zigeng/SlimSAM-uniform-50) | Zero-shot-instance-segmentation |

<a name="example"></a>
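The new `aimv2` value slots into the CLI exactly like `clip`. A hypothetical invocation for illustration, using only the flags and defaults quoted above (any further required arguments follow the rest of the README's option list):

```bash
# Illustrative only: select AIMv2 for zero-shot image classification annotation.
datadreamer \
  --prompt_generator qwen2 \
  --image_generator sdxl-turbo \
  --image_annotator aimv2 \
  --conf_threshold 0.15
```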
4 changes: 4 additions & 0 deletions datadreamer/dataset_annotation/__init__.py
@@ -1,14 +1,18 @@
from __future__ import annotations

from .aimv2_annotator import AIMv2Annotator
from .clip_annotator import CLIPAnnotator
from .cls_annotator import ImgClassificationAnnotator
from .image_annotator import BaseAnnotator, TaskList
from .owlv2_annotator import OWLv2Annotator
from .slimsam_annotator import SlimSAMAnnotator

__all__ = [
"AIMv2Annotator",
"BaseAnnotator",
"TaskList",
"OWLv2Annotator",
"ImgClassificationAnnotator",
"CLIPAnnotator",
"SlimSAMAnnotator",
]
68 changes: 68 additions & 0 deletions datadreamer/dataset_annotation/aimv2_annotator.py
@@ -0,0 +1,68 @@
"""This file uses pre-trained model derived from Apple's software, provided under the
Apple Sample Code License license. The license is available at:

https://developer.apple.com/support/downloads/terms/apple-sample-code/Apple-Sample-Code-License.pdf

In addition, this file and other parts of the repository are licensed under the Apache 2.0
License. By using this file, you agree to comply with the terms of both licenses.
"""
from __future__ import annotations

import logging

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

from datadreamer.dataset_annotation.cls_annotator import ImgClassificationAnnotator

logger = logging.getLogger(__name__)


class AIMv2Annotator(ImgClassificationAnnotator):
"""A class for image annotation using the AIMv2 model, specializing in image
classification.

Attributes:
model (AutoModel): The AIMv2 model for image-text similarity evaluation.
processor (AutoProcessor): The processor for preparing inputs to the AIMv2 model.
device (str): The device on which the model will run ('cuda' for GPU, 'cpu' for CPU).
size (str): The size of the AIMv2 model to use ('base' or 'large').

Methods:
_init_processor(): Initializes the AIMv2 processor.
_init_model(): Initializes the AIMv2 model.
annotate_batch(images, objects, conf_threshold, synonym_dict): Annotates the given images with classification labels.
release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
"""

def _init_processor(self) -> AutoProcessor:
"""Initializes the AIMv2 processor.

Returns:
AutoProcessor: The initialized AIMv2 processor.
"""
return AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")

def _init_model(self) -> AutoModel:
"""Initializes the AIMv2 model.

Returns:
AutoModel: The initialized AIMv2 model.
"""
logger.info(f"Initializing AIMv2 {self.size} model...")
return AutoModel.from_pretrained(
"apple/aimv2-large-patch14-224-lit", trust_remote_code=True
)


if __name__ == "__main__":
import requests

device = "cuda" if torch.cuda.is_available() else "cpu"
url = "https://ultralytics.com/images/bus.jpg"
im = Image.open(requests.get(url, stream=True).raw)
annotator = AIMv2Annotator(device=device)
labels = annotator.annotate_batch([im], ["bus", "people"])
print(labels)
annotator.release()
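The `__main__` smoke test above uses the defaults; the inherited `annotate_batch` (defined in `cls_annotator.py` below) also accepts a confidence threshold and a synonym dictionary. A minimal sketch under the same assumptions as the smoke test; note that when `synonym_dict` is supplied, every queried class name must appear as a key, and the returned indices refer to the original class list:

```python
# Hedged sketch: same checkpoint and test image as the __main__ block above,
# but with an explicit confidence threshold and a synonym dictionary.
import requests
import torch
from PIL import Image

from datadreamer.dataset_annotation import AIMv2Annotator

device = "cuda" if torch.cuda.is_available() else "cpu"
im = Image.open(requests.get("https://ultralytics.com/images/bus.jpg", stream=True).raw)

annotator = AIMv2Annotator(device=device)
# Every queried class ("bus", "person") must be a key of synonym_dict;
# returned label indices map back to the original list ("bus" -> 0, "person" -> 1).
labels = annotator.annotate_batch(
    [im],
    ["bus", "person"],
    conf_threshold=0.2,
    synonym_dict={"bus": ["coach"], "person": ["pedestrian", "people"]},
)
print(labels)  # e.g. [array([0])] if only the bus class clears the threshold
annotator.release()
```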
102 changes: 2 additions & 100 deletions datadreamer/dataset_annotation/clip_annotator.py
@@ -1,20 +1,17 @@
from __future__ import annotations

import logging
from typing import Dict, List

import numpy as np
import PIL
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

from datadreamer.dataset_annotation.image_annotator import BaseAnnotator, TaskList
from datadreamer.dataset_annotation.cls_annotator import ImgClassificationAnnotator

logger = logging.getLogger(__name__)


class CLIPAnnotator(BaseAnnotator):
class CLIPAnnotator(ImgClassificationAnnotator):
"""A class for image annotation using the CLIP model, specializing in image
classification.

@@ -31,25 +28,6 @@ class CLIPAnnotator(BaseAnnotator):
release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
"""

def __init__(
self,
seed: float = 42,
device: str = "cuda",
size: str = "base",
) -> None:
"""Initializes the CLIPAnnotator with a specific seed and device.

Args:
seed (float): Seed for reproducibility. Defaults to 42.
device (str): The device to run the model on. Defaults to 'cuda'.
"""
super().__init__(seed, task_definition=TaskList.CLASSIFICATION)
self.size = size
self.model = self._init_model()
self.processor = self._init_processor()
self.device = device
self.model.to(self.device)

def _init_processor(self) -> CLIPProcessor:
"""Initializes the CLIP processor.

@@ -71,82 +49,6 @@ def _init_model(self) -> CLIPModel:
return CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
return CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

def annotate_batch(
self,
images: List[PIL.Image.Image],
objects: List[str],
conf_threshold: float = 0.1,
synonym_dict: Dict[str, List[str]] | None = None,
) -> List[np.ndarray]:
"""Annotates images using the OWLv2 model.

Args:
images: The images to be annotated.
objects: A list of objects (text) to test against the images.
conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.
synonym_dict (dict, optional): Dictionary for handling synonyms in labels. Defaults to None.

Returns:
List[np.ndarray]: A list of the annotations for each image.
"""
if synonym_dict is not None:
objs_syn = set()
for obj in objects:
objs_syn.add(obj)
for syn in synonym_dict[obj]:
objs_syn.add(syn)
objs_syn = list(objs_syn)
# Make a dict to transform synonym ids to original ids
synonym_dict_rev = {}
for key, value in synonym_dict.items():
if key in objects:
synonym_dict_rev[objs_syn.index(key)] = objects.index(key)
for v in value:
synonym_dict_rev[objs_syn.index(v)] = objects.index(key)
objects = objs_syn

inputs = self.processor(
text=objects, images=images, return_tensors="pt", padding=True
).to(self.device)

outputs = self.model(**inputs)

logits_per_image = outputs.logits_per_image # image-text similarity score
probs = logits_per_image.softmax(dim=1).cpu() # label probabilities

labels = []
# Get the labels for each image
if synonym_dict is not None:
for prob in probs:
labels.append(
np.unique(
np.array(
[
synonym_dict_rev[label.item()]
for label in torch.where(prob > conf_threshold)[
0
].numpy()
]
)
)
)
else:
for prob in probs:
labels.append(torch.where(prob > conf_threshold)[0].numpy())

return labels

def release(self, empty_cuda_cache: bool = False) -> None:
"""Releases the model and optionally empties the CUDA cache.

Args:
empty_cuda_cache (bool, optional): Whether to empty the CUDA cache. Defaults to False.
"""
self.model = self.model.to("cpu")
if empty_cuda_cache:
with torch.no_grad():
torch.cuda.empty_cache()


if __name__ == "__main__":
import requests
130 changes: 130 additions & 0 deletions datadreamer/dataset_annotation/cls_annotator.py
@@ -0,0 +1,130 @@
from __future__ import annotations

import logging
from typing import Dict, List

import numpy as np
import PIL
import torch

from datadreamer.dataset_annotation.image_annotator import BaseAnnotator, TaskList

logger = logging.getLogger(__name__)


class ImgClassificationAnnotator(BaseAnnotator):
"""Base class for image classification annotators using transformers models.

Attributes:
model: The model for image-text similarity evaluation.
processor: The processor for preparing inputs to the model.
device (str): The device on which the model will run ('cuda' for GPU, 'cpu' for CPU).
size (str): The size of the model to use ('base' or 'large').

Methods:
_init_processor(): Initializes the processor.
_init_model(): Initializes the model.
annotate_batch(images, objects, conf_threshold, synonym_dict): Annotates the given images with classification labels.
release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
"""

def __init__(
self, seed: float = 42, device: str = "cuda", size: str = "base"
) -> None:
"""Initializes the image classification annotator.

Args:
seed (float): Seed for reproducibility. Defaults to 42.
device (str): The device to run the model on. Defaults to 'cuda'.
size (str): The model size to use.
"""
super().__init__(seed, task_definition=TaskList.CLASSIFICATION)
self.size = size
self.device = device
self.model = self._init_model()
self.processor = self._init_processor()
self.model.to(self.device)

def _init_processor(self):
"""Initializes the processor."""
raise NotImplementedError

def _init_model(self):
"""Initializes the model."""
raise NotImplementedError

def annotate_batch(
self,
images: List[PIL.Image.Image],
objects: List[str],
conf_threshold: float = 0.1,
synonym_dict: Dict[str, List[str]] | None = None,
) -> List[np.ndarray]:
"""Annotates images using the CLIP model.

Args:
images: The images to be annotated.
objects: A list of objects (text) to test against the images.
conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.
synonym_dict (dict, optional): Dictionary for handling synonyms in labels. Defaults to None.

Returns:
List[np.ndarray]: A list of the annotations for each image.
"""
if synonym_dict is not None:
objs_syn = set()
for obj in objects:
objs_syn.add(obj)
for syn in synonym_dict[obj]:
objs_syn.add(syn)
objs_syn = list(objs_syn)
# Make a dict to transform synonym ids to original ids
synonym_dict_rev = {}
for key, value in synonym_dict.items():
if key in objects:
synonym_dict_rev[objs_syn.index(key)] = objects.index(key)
for v in value:
synonym_dict_rev[objs_syn.index(v)] = objects.index(key)
objects = objs_syn

inputs = self.processor(
text=objects, images=images, return_tensors="pt", padding=True
).to(self.device)

outputs = self.model(**inputs)

logits_per_image = outputs.logits_per_image # image-text similarity score
probs = logits_per_image.softmax(dim=1).cpu() # label probabilities

labels = []
# Get the labels for each image
if synonym_dict is not None:
for prob in probs:
labels.append(
np.unique(
np.array(
[
synonym_dict_rev[label.item()]
for label in torch.where(prob > conf_threshold)[
0
].numpy()
]
)
)
)
else:
for prob in probs:
labels.append(torch.where(prob > conf_threshold)[0].numpy())

return labels

def release(self, empty_cuda_cache: bool = False) -> None:
"""Releases the model and optionally empties the CUDA cache.

Args:
empty_cuda_cache (bool, optional): Whether to empty the CUDA cache. Defaults to False.
"""
self.model = self.model.to("cpu")
if empty_cuda_cache:
with torch.no_grad():
torch.cuda.empty_cache()
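This base class is what lets `CLIPAnnotator` shrink to two overrides and keeps `AIMv2Annotator` small: `annotate_batch` and `release` live once in `ImgClassificationAnnotator`. As a minimal sketch, a further zero-shot classifier would only need its own `_init_model` and `_init_processor`; the checkpoint name below is a hypothetical placeholder, and any real checkpoint used this way must return `logits_per_image` for processor-built inputs, as CLIP and AIMv2 do:

```python
# Hedged sketch: plugging another zero-shot classifier into the shared base class.
# "org/some-zero-shot-checkpoint" is a hypothetical placeholder, not a real model ID.
from transformers import AutoModel, AutoProcessor

from datadreamer.dataset_annotation.cls_annotator import ImgClassificationAnnotator


class MyZeroShotAnnotator(ImgClassificationAnnotator):
    def _init_processor(self) -> AutoProcessor:
        # Processor that tokenizes the class names and preprocesses the images.
        return AutoProcessor.from_pretrained("org/some-zero-shot-checkpoint")

    def _init_model(self) -> AutoModel:
        # Model whose outputs expose `logits_per_image`, as the base class expects.
        return AutoModel.from_pretrained("org/some-zero-shot-checkpoint")
```

`annotate_batch` and `release` are then inherited unchanged, which is exactly how `CLIPAnnotator` and `AIMv2Annotator` are wired in this PR.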
5 changes: 3 additions & 2 deletions datadreamer/pipelines/generate_dataset_from_scratch.py
@@ -17,6 +17,7 @@
from tqdm import tqdm

from datadreamer.dataset_annotation import (
AIMv2Annotator,
CLIPAnnotator,
OWLv2Annotator,
SlimSAMAnnotator,
@@ -57,7 +58,7 @@
}

det_annotators = {"owlv2": OWLv2Annotator}
clf_annotators = {"clip": CLIPAnnotator}
clf_annotators = {"clip": CLIPAnnotator, "aimv2": AIMv2Annotator}
inst_seg_annotators = {"owlv2-slimsam": SlimSAMAnnotator}
inst_seg_detectors = {"owlv2-slimsam": OWLv2Annotator}

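The `aimv2` key added to `clf_annotators` above is all the pipeline needs to resolve the new `--image_annotator` choice (extended in the `parse_args` hunk below). A minimal sketch of that selection step, assuming only the registry shown here; the real pipeline may pass further constructor arguments such as `size`:

```python
# Hedged sketch of the selection step: the parsed --image_annotator value is
# looked up in the classification registry and the matching class is instantiated.
import torch

from datadreamer.dataset_annotation import AIMv2Annotator, CLIPAnnotator

clf_annotators = {"clip": CLIPAnnotator, "aimv2": AIMv2Annotator}

image_annotator = "aimv2"  # as chosen via --image_annotator
annotator = clf_annotators[image_annotator](
    device="cuda" if torch.cuda.is_available() else "cpu"
)
```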
@@ -122,7 +123,7 @@ def parse_args():
parser.add_argument(
"--image_annotator",
type=str,
choices=["owlv2", "clip", "owlv2-slimsam"],
choices=["owlv2", "clip", "owlv2-slimsam", "aimv2"],
help="Image annotator to use",
)

Expand Down
Loading