Skip to content

Commit

Permalink
update for torch coco training getting rid of ultralytics dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
leondgarse committed Feb 21, 2024
1 parent 415098d commit 98f2da5
Show file tree
Hide file tree
Showing 14 changed files with 110 additions and 162 deletions.
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -417,14 +417,11 @@
CUDA_VISIBLE_DEVICES='1' python3 coco_eval_script.py -m checkpoints/yoloxtiny_yolor_anchor.h5
# >>>> [COCOEvalCallback] input_shape: (416, 416), pyramid_levels: [3, 5], anchors_mode: yolor
```
- **[Experimental] Training using PyTorch backend**, currently using the `ultralytics` dataset and validator process. The parameter `rect_val=False` means the validator uses a fixed data shape of `[640, 640]`; otherwise the data shape is dynamic.
- **[Experimental] Training using PyTorch backend**
```py
!pip install ultralytics

import os, sys
os.environ["KECAM_BACKEND"] = "torch"
sys.setrecursionlimit(65536)
# sys.path.append(os.path.expanduser("~/workspace/ultralytics/"))

from keras_cv_attention_models.yolov8 import train, yolov8, torch_wrapper
from keras_cv_attention_models import efficientnet
Expand All @@ -433,8 +430,7 @@
bb = efficientnet.EfficientNetV2B0(input_shape=(3, 640, 640), num_classes=0)
model = yolov8.YOLOV8_N(backbone=bb, classifier_activation=None, pretrained=None).cuda()
# model = yolov8.YOLOV8_N(input_shape=(3, None, None), classifier_activation=None, pretrained=None).cuda()
model = torch_wrapper.Detect(model)
ema = train.train(model, dataset_path="coco.yaml", rect_val=False)
ema = train.train(model, dataset_path="coco.json")
```
![yolov8_training](https://user-images.githubusercontent.com/5744524/235142289-cb6a4da0-1ea7-4261-afdd-03a3c36278b8.png)
## CLIP training and evaluating
Expand Down
1 change: 1 addition & 0 deletions coco_eval_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
except:
pass


def parse_arguments(argv):
import argparse

Expand Down
11 changes: 9 additions & 2 deletions custom_dataset_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,16 @@ def match_detection_labels_coco_annotation(image_names, label_path, target_ids=N

def match_detection_labels(image_names, label_path):
if label_path.endswith(".json"):
return match_detection_labels_coco_annotation(image_names, label_path)
x_train, y_train, indices_2_labels = match_detection_labels_coco_annotation(image_names, label_path)
else:
return match_detection_labels_dir(image_names, label_path)
x_train, y_train, indices_2_labels = match_detection_labels_dir(image_names, label_path)

""" Adding images with no labels as backgrounds """
num_backgrounds = len(image_names) - len(x_train)
print(">>>> Total instances: {}, pure backgrounds: {}".format(len(image_names), num_backgrounds))
x_train.extend(list(set(image_names) - set(x_train)))
y_train.extend([{"label": [], "bbox": []}] * num_backgrounds)
return x_train, y_train, indices_2_labels


def convert_to_corner_by_format(bbox, bbox_source_format="yxyx"):
Expand Down
14 changes: 11 additions & 3 deletions keras_cv_attention_models/coco/eval_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
mode="global", # decode parameter, can be set new value in `self.call`
topk=0, # decode parameter, can be set new value in `self.call`
use_static_output=False, # Set to True if using this as an actual layer, especially for converting tflite
use_sigmoid_on_score=False, # whether to apply sigmoid on score outputs. Set True if the model is built using `classifier_activation=None`
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -58,7 +59,7 @@ def __init__(
else:
self.anchors = None
self.__input_shape__ = input_shape
self.use_static_output = use_static_output
self.use_static_output, self.use_sigmoid_on_score = use_static_output, use_sigmoid_on_score
self.nms_kwargs = {
"score_threshold": score_threshold,
"iou_or_sigma": iou_or_sigma,
Expand Down Expand Up @@ -163,6 +164,7 @@ def __decode_single__(self, pred, score_threshold=0.3, iou_or_sigma=0.5, max_out
anchors = self.anchors
if self.use_object_scores:
ccs = ccs * object_scores
ccs = functional.sigmoid(ccs) if self.use_sigmoid_on_score else ccs

# print(f"{bbs.shape = }, {anchors.shape = }")
bbs_decoded = anchors_func.decode_bboxes(bbs, anchors, regression_len=self.regression_len)
Expand Down Expand Up @@ -496,7 +498,9 @@ def __init__(
def build(self, input_shape, output_shape):
import re

input_shape = (int(input_shape[1]), int(input_shape[2])) if backend.image_data_format() == "channels_last" else (int(input_shape[2]), int(input_shape[3]))
input_shape = (
(int(input_shape[1]), int(input_shape[2])) if backend.image_data_format() == "channels_last" else (int(input_shape[2]), int(input_shape[3]))
)
self.eval_dataset, self.num_classes = init_eval_dataset(input_shape=input_shape, **self.dataset_kwargs)
print("\n>>>> [COCOEvalCallback] self.dataset_kwargs:", self.dataset_kwargs)
regression_len = (output_shape[-1] - self.num_classes) // 4 * 4
Expand All @@ -512,7 +516,11 @@ def build(self, input_shape, output_shape):
# print(">>>>", self.dataset_kwargs)
# print(">>>>", self.nms_kwargs)

self.pred_decoder = DecodePredictions(input_shape, pyramid_levels, self.anchors_mode, regression_len=regression_len, **self.anchor_kwargs)
use_sigmoid_on_score = not any([ii.name.endswith("_sigmoid") for ii in self.model.layers[-50:]])
print(">>>> use_sigmoid_on_score:", use_sigmoid_on_score)
self.pred_decoder = DecodePredictions(
input_shape, pyramid_levels, self.anchors_mode, regression_len=regression_len, use_sigmoid_on_score=use_sigmoid_on_score, **self.anchor_kwargs
)

# Training saving best
if self.model_basic_save_name is not None:
Expand Down
2 changes: 2 additions & 0 deletions keras_cv_attention_models/coco/tf_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,10 @@ def detection_dataset_from_custom_json(data_path, with_info=False):
base_path = os.path.expanduser(info["base_path"])
for ii in train:
ii["image"] = os.path.join(base_path, ii["image"])
ii["objects"]["bbox"] = tf.reshape(ii["objects"]["bbox"], [-1, 4])
for ii in test:
ii["image"] = os.path.join(base_path, ii["image"])
ii["objects"]["bbox"] = tf.reshape(ii["objects"]["bbox"], [-1, 4])

objects_signature = {"bbox": tf.TensorSpec(shape=(None, 4), dtype=tf.float32), "label": tf.TensorSpec(shape=(None,), dtype=tf.int64)}
output_signature = {"image": tf.TensorSpec(shape=(), dtype=tf.string), "objects": objects_signature}
Expand Down
4 changes: 2 additions & 2 deletions keras_cv_attention_models/coco/torch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def __init__(self, data, image_size=640, batch_size=16, is_train=True, mosaic=1.
datapoint["objects"]["label"] = np.array(datapoint["objects"]["label"], dtype="int64")

import cv2

self.imread = cv2.imread

""" Cache first for using in mosaic mix """
Expand All @@ -209,7 +210,7 @@ def __process_eval__(self, index, image_path, bbox, label):
return image, bbox, label

def __process_train__(self, index, image_path, bbox, label):
""" Cache read images """
"""Cache read images"""
if index in self.cached_image_indexes: # Seldom should this happen
image = self.cached_images[index]
else:
Expand Down Expand Up @@ -245,7 +246,6 @@ def __process_train__(self, index, image_path, bbox, label):
bbox[:, [1, 3]] = 1 - bbox[:, [3, 1]]
return image, bbox, label


def __getitem__(self, index):
datapoint = self.data[index]
image_path, objects = datapoint["image"], datapoint["objects"]
Expand Down
2 changes: 1 addition & 1 deletion keras_cv_attention_models/gpt2/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def build(self, input_shape):

def call(self, inputs):
qq_len, kk_len = (functional.shape(inputs)[2], functional.shape(inputs)[3]) if backend.is_tensorflow_backend else (inputs.shape[2], inputs.shape[3])
return inputs + self.causal_mask[:, :, : qq_len, : kk_len]
return inputs + self.causal_mask[:, :, :qq_len, :kk_len]

def get_config(self):
base_config = super().get_config()
Expand Down
1 change: 1 addition & 0 deletions keras_cv_attention_models/imagenet/eval_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class TFLiteModelInterf:
>>> # >>>> Calling resize_tensor_input, input_shape (1, 224, 224, 3) -> (1, 119, 75, 3):
>>> # tt([np.ones([119, 75, 3]).astype('float32')]).shape = (1, 1000)
"""

def __init__(self, model_path):
import tensorflow as tf

Expand Down
2 changes: 1 addition & 1 deletion keras_cv_attention_models/llama2/llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def build(self, input_shape):
def call(self, inputs, **kwargs):
left, right = functional.unstack(inputs, axis=-2)
seq_len = functional.shape(left)[-3] if backend.is_tensorflow_backend else left.shape[-3]
pos_cos, pos_sin = self.pos_cos[: seq_len], self.pos_sin[: seq_len]
pos_cos, pos_sin = self.pos_cos[:seq_len], self.pos_sin[:seq_len]
out = functional.stack([left * pos_cos - right * pos_sin, right * pos_cos + left * pos_sin], axis=-2)
return out

Expand Down
50 changes: 7 additions & 43 deletions keras_cv_attention_models/yolov8/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,69 +230,33 @@
# Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.855
```
## Training using PyTorch backend
- **[Experimental] Training using PyTorch backend**, currently using `ultralytics` dataset and validator process. The advantage is that this supports any pyramid staged model in this package.
- The parameter `rect_val=False` means the validator uses a fixed data shape of `[640, 640]`; otherwise the data shape is dynamic.
- **Custom dataset** is created using `custom_dataset_script.py`, which can be used as `dataset_path="coco.json"` for training. Detailed usage can be found in [Custom detection dataset](https://github.com/leondgarse/keras_cv_attention_models/discussions/52#discussioncomment-2460664).
- **Train using `EfficientNetV2B0` backbone + `YOLOV8_N` head**.
```py
import os, sys
import os, sys, torch
os.environ["KECAM_BACKEND"] = "torch"
sys.path.append(os.path.expanduser("~/workspace/ultralytics/"))

from keras_cv_attention_models.yolov8 import train, yolov8, torch_wrapper
from keras_cv_attention_models import efficientnet

global_device = torch.device("cuda:0") if torch.cuda.is_available() and int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")) >= 0 else torch.device("cpu")
# model Trainable params: 7,023,904, GFLOPs: 8.1815G
bb = efficientnet.EfficientNetV2B0(input_shape=(3, 640, 640), num_classes=0)
model = yolov8.YOLOV8_N(backbone=bb, classifier_activation=None, pretrained=None).cuda()
model = yolov8.YOLOV8_N(backbone=bb, classifier_activation=None, pretrained=None).to(global_device) # Note: classifier_activation=None
# model = yolov8.YOLOV8_N(input_shape=(3, None, None), classifier_activation=None, pretrained=None).cuda()
model = torch_wrapper.Detect(model)
ema = train.train(model, dataset_path="coco.yaml", rect_val=False)
ema = train.train(model, dataset_path="coco.json", initial_epoch=0)
```
![yolov8_training](https://user-images.githubusercontent.com/5744524/235142289-cb6a4da0-1ea7-4261-afdd-03a3c36278b8.png)
- **Predict after training** **Note: currently trained weights output format is `[left, top, right, bottom]`, or `xyxy` format, while eval and show here using `[top, left, bottom, right]`, or `yxyx` format. This may change in the future.**
- **Predict after training using Torch / TF backend**. `bbox` output format is in `[top, left, bottom, right]`, or `yxyx` format.
```py
os.environ["KECAM_BACKEND"] = "torch"
import torch
from keras_cv_attention_models import efficientnet, yolov8, test_images
from keras_cv_attention_models.coco import data
from keras_cv_attention_models.backend import numpy_image_resize, functional
from keras_cv_attention_models.yolov8 import torch_wrapper

bb = efficientnet.EfficientNetV2B0(input_shape=(3, 640, 640), num_classes=0, pretrained=None)
model = yolov8.YOLOV8_N(backbone=bb, classifier_activation=None, pretrained="yolov8_n_E40_0291.h5")
tt = torch_wrapper.Detect(model)
_ = tt.eval()

imm = numpy_image_resize(test_images.dog_cat(), [640, 640]) / 255
preds_torch, torch_out = tt(torch.from_numpy(imm[None]).permute([0, 3, 1, 2]).float())
print(preds_torch.shape, [ii.shape for ii in torch_out]) # This should be same with ultralytics model prediction output
# torch.Size([1, 84, 8400]) [torch.Size([1, 144, 80, 80]), torch.Size([1, 144, 40, 40]), torch.Size([1, 144, 20, 20])]

"""" Convert bboxes xywh to yxyx """
bboxes, scores, labels = preds_torch[0, :4].T, *preds_torch[0, 4:].max(0)
top_left = bboxes[:, [1, 0]] - bboxes[:, [3, 2]] / 2
bboxes = torch.concat([top_left, top_left + bboxes[:, [3, 2]]], dim=-1)

rr, nms_scores = functional.non_max_suppression_with_scores(bboxes, scores, iou_threshold=0.5, score_threshold=0.3)
data.show_image_with_bboxes(imm, bboxes[rr] / 640, labels[rr])
```
**Inference using TF backend**
```py
import tensorflow as tf
from keras_cv_attention_models import efficientnet, yolov8, test_images
from keras_cv_attention_models.coco import data

bb = efficientnet.EfficientNetV2B0(input_shape=(3, 640, 640), num_classes=0, pretrained=None)
model = yolov8.YOLOV8_N(backbone=bb, pretrained="yolov8_n_E40_0291.h5")
model = yolov8.YOLOV8_N(backbone=bb, pretrained="yolov8_n.h5")

imm = test_images.dog_cat()
preds = model(model.preprocess_input(imm))
print(preds.shape)
# (1, 8400, 144)

""" Convert xyxy to yxyx """
preds_bbox = tf.reshape(tf.gather(tf.reshape(preds[:, :, :64], [1, -1, 4, 16]), [1, 0, 3, 2], axis=2), [1, -1, 64])
preds = tf.concat([preds_bbox, preds[:, :, 64:]], axis=-1)
bboxes, labels, confidences = model.decode_predictions(preds)[0]
data.show_image_with_bboxes(imm, bboxes, labels, confidences)
```
Expand Down
8 changes: 4 additions & 4 deletions keras_cv_attention_models/yolov8/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def to_data_loader(data, cfg, imgsz=640, mode="train", batch_size=16, rect_val=F
if mode == "train":
augment, pad, shuffle, rect, batch_size = True, 0, True, False, batch_size
else:
augment, pad, shuffle, rect, batch_size = False, 0.5, False, rect_val, batch_size * 2
augment, pad, shuffle, rect, batch_size = False, 0, False, rect_val, batch_size * 2
dataset = YOLODataset(
img_path=data["train"] if mode == "train" else data["val"],
imgsz=imgsz,
Expand All @@ -39,13 +39,13 @@ def to_data_loader(data, cfg, imgsz=640, mode="train", batch_size=16, rect_val=F
return data_loader


def get_data_loader(dataset_path="coco.yaml", cfg={}, imgsz=640, batch_size=16, rect_val=False):
def get_data_loader(dataset_path="coco.yaml", cfg={}, imgsz=640, batch_size=16, rect_val=False, split="all"):
cfg = get_cfg(DEFAULT_CFG)
cfg.data = dataset_path
cfg.imgsz = imgsz
data = check_det_dataset(dataset_path)
train_loader = to_data_loader(data, cfg, imgsz=imgsz, mode="train", batch_size=batch_size)
val_loader = to_data_loader(data, cfg, imgsz=imgsz, mode="val", batch_size=batch_size, rect_val=rect_val)
train_loader = None if split.lower() == "val" else to_data_loader(data, cfg, imgsz=imgsz, mode="train", batch_size=batch_size)
val_loader = None if split.lower() == "train" else to_data_loader(data, cfg, imgsz=imgsz, mode="val", batch_size=batch_size, rect_val=rect_val)
return train_loader, val_loader


Expand Down
3 changes: 2 additions & 1 deletion keras_cv_attention_models/yolov8/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,14 @@ def __call__(self):
# from ultralytics import YOLO

dataset_path = "coco128.yaml"
train_loader, val_loader = data.get_data_loader(dataset_path=dataset_path, rect_val=True)
train_loader, val_loader = data.get_data_loader(dataset_path=dataset_path, rect_val=True, split="val")
cfg = train.FakeArgs(data=dataset_path, imgsz=640, iou=0.7, conf=0.001, single_cls=False, max_det=300, task="detect", mode="train", split="val", half=False)
cfg.update(degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, flipud=0.0, fliplr=0.5)
cfg.update(mask_ratio=4, overlap_mask=True, project=None, name=None, save_txt=False, save_hybrid=False, save_json=False, plots=False, verbose=True)

# model = YOLO('../ultralytics/ultralytics/models/v8/yolov8n.yaml').model
model = yolov8.YOLOV8_N(input_shape=(3, None, None), classifier_activation=None, pretrained=None)
model = torch_wrapper.Detect(model)
_ = model.eval()
ee = eval.Validator(model, val_loader, cfg=cfg)
ee()
Loading

0 comments on commit 98f2da5

Please sign in to comment.