diff --git a/data/hyps/hyp.hic-yolov5s.yaml b/data/hyps/hyp.hic-yolov5s.yaml
new file mode 100644
index 000000000000..80f6e6dd95c4
--- /dev/null
+++ b/data/hyps/hyp.hic-yolov5s.yaml
@@ -0,0 +1,33 @@
+# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
+# Hyperparameters for HIC-YOLOv5 small-object detection on the VisDrone dataset
+# python train.py --hyp data/hyps/hyp.hic-yolov5s.yaml --cfg models/hub/yolov5s-cbam-involution.yaml --data VisDrone.yaml
+
+lr0: 0.001  # initial learning rate (SGD=1E-2, Adam=1E-3)
+lrf: 0.01  # final OneCycleLR learning rate (lr0 * lrf)
+momentum: 0.937  # SGD momentum/Adam beta1
+weight_decay: 0.0005  # optimizer weight decay 5e-4
+warmup_epochs: 3.0  # warmup epochs (fractions ok)
+warmup_momentum: 0.8  # warmup initial momentum
+warmup_bias_lr: 0.1  # warmup initial bias lr
+box: 0.05  # box loss gain
+cls: 0.25  # cls loss gain
+cls_pw: 1.0  # cls BCELoss positive_weight
+obj: 0.5  # obj loss gain (scale with pixels)
+obj_pw: 1.0  # obj BCELoss positive_weight
+iou_t: 0.20  # IoU training threshold
+anchor_t: 4.0  # anchor-multiple threshold
+# anchors: 3  # anchors per output layer (0 to ignore)
+fl_gamma: 0.0  # focal loss gamma (EfficientDet default gamma=1.5)
+hsv_h: 0.4  # image HSV-Hue augmentation (fraction)
+hsv_s: 0.3  # image HSV-Saturation augmentation (fraction)
+hsv_v: 0.5  # image HSV-Value augmentation (fraction)
+degrees: 0.2  # image rotation (+/- deg)
+translate: 0.1  # image translation (+/- fraction)
+scale: 0.4  # image scale (+/- gain)
+shear: 0.0  # image shear (+/- deg)
+perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
+flipud: 0.0  # image flip up-down (probability)
+fliplr: 0.5  # image flip left-right (probability)
+mosaic: 1.0  # image mosaic (probability)
+mixup: 0.2  # image mixup (probability)
+copy_paste: 0.1  # segment copy-paste (probability)
diff --git a/models/common.py b/models/common.py
index 75cc4e97bbc7..521551f273d0 100644
--- a/models/common.py
+++ b/models/common.py
@@ -881,3 +881,165 @@ def forward(self, x):
         if isinstance(x, list):
             x = torch.cat(x, 1)
         return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
+
+
+# contributed by @aash1999
+class ChannelAttention(nn.Module):
+
+    def __init__(self, in_planes, ratio=16):
+        """
+        Initialize the Channel Attention module.
+
+        Args:
+            in_planes (int): Number of input channels.
+            ratio (int): Reduction ratio for the hidden channels in the channel attention block.
+        """
+        super().__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.max_pool = nn.AdaptiveMaxPool2d(1)
+        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
+        self.relu = nn.ReLU()
+        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        """
+        Forward pass of the Channel Attention module.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            out (torch.Tensor): Output tensor after applying channel attention.
+        """
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')
+            avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
+            max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
+            out = self.sigmoid(avg_out + max_out)
+            return out
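+
+
+# Usage sketch (illustrative): for x of shape (B, C, H, W), ChannelAttention(C)(x) returns a
+# (B, C, 1, 1) sigmoid gate built from avg- and max-pooled descriptors passed through a shared
+# two-layer 1x1-conv MLP; CBAM below applies it as gate * x to reweight channels.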
+
+
+# contributed by @aash1999
+class SpatialAttention(nn.Module):
+
+    def __init__(self, kernel_size=7):
+        """
+        Initialize the Spatial Attention module.
+
+        Args:
+            kernel_size (int): Size of the convolutional kernel for spatial attention.
+        """
+        super().__init__()
+        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
+        padding = 3 if kernel_size == 7 else 1
+        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        """
+        Forward pass of the Spatial Attention module.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            out (torch.Tensor): Output tensor after applying spatial attention.
+        """
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')
+            avg_out = torch.mean(x, dim=1, keepdim=True)
+            max_out, _ = torch.max(x, dim=1, keepdim=True)
+            x = torch.cat([avg_out, max_out], dim=1)
+            x = self.conv(x)
+            return self.sigmoid(x)
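+
+
+# Usage sketch (illustrative): SpatialAttention()(x) reduces x to its channel-wise mean and max,
+# concatenates them into a (B, 2, H, W) map and returns a (B, 1, H, W) sigmoid gate that is
+# broadcast-multiplied over all channels.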
+
+
+# contributed by @aash1999
+class CBAM(nn.Module):
+    # ch_in, ch_out, kernel_size, shortcut, groups, expansion, ratio
+    def __init__(self, c1, c2, kernel_size=3, shortcut=True, g=1, e=0.5, ratio=16):
+        """
+        Initialize the CBAM (Convolutional Block Attention Module).
+
+        Args:
+            c1 (int): Number of input channels.
+            c2 (int): Number of output channels.
+            kernel_size (int): Kernel size for the spatial attention convolution.
+            shortcut (bool): Whether to use a shortcut connection.
+            g (int): Number of groups for grouped convolutions.
+            e (float): Expansion factor for hidden channels.
+            ratio (int): Reduction ratio for the hidden channels in the channel attention block.
+        """
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c_, c2, 3, 1, g=g)
+        self.add = shortcut and c1 == c2
+        self.channel_attention = ChannelAttention(c2, ratio)
+        self.spatial_attention = SpatialAttention(kernel_size)
+
+    def forward(self, x):
+        """
+        Forward pass of the CBAM module.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            out (torch.Tensor): Output tensor after applying the CBAM bottleneck.
+        """
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')
+            x2 = self.cv2(self.cv1(x))
+            out = self.channel_attention(x2) * x2
+            out = self.spatial_attention(out) * out
+            return x + out if self.add else out
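+
+
+# Shape note (illustrative): CBAM(c1, c2) maps (B, c1, H, W) -> (B, c2, H, W); the residual add is
+# only taken when shortcut=True and c1 == c2, mirroring the Bottleneck convention used elsewhere
+# in this file.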
+
+
+# contributed by @aash1999
+class Involution(nn.Module):
+
+    def __init__(self, c1, c2, kernel_size, stride):
+        """
+        Initialize the Involution module.
+
+        Args:
+            c1 (int): Number of input channels.
+            c2 (int): Number of output channels.
+            kernel_size (int): Size of the involution kernel.
+            stride (int): Stride for the involution operation.
+        """
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.c1 = c1
+        reduction_ratio = 1
+        self.group_channels = 16
+        self.groups = self.c1 // self.group_channels
+        self.conv1 = Conv(c1, c1 // reduction_ratio, 1)
+        self.conv2 = Conv(c1 // reduction_ratio, kernel_size ** 2 * self.groups, 1, 1)
+
+        if stride > 1:
+            self.avgpool = nn.AvgPool2d(stride, stride)
+        self.unfold = nn.Unfold(kernel_size, 1, (kernel_size - 1) // 2, stride)
+
+    def forward(self, x):
+        """
+        Forward pass of the Involution module.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            out (torch.Tensor): Output tensor after applying the involution operation.
+        """
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')
+            weight = self.conv2(self.conv1(x if self.stride == 1 else self.avgpool(x)))
+            b, c, h, w = weight.shape
+            weight = weight.view(b, self.groups, self.kernel_size ** 2, h, w).unsqueeze(2)
+            out = self.unfold(x).view(b, self.groups, self.group_channels, self.kernel_size ** 2, h, w)
+            out = (weight * out).sum(dim=3).view(b, self.c1, h, w)
+
+            return out
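+
+
+# Shape note (illustrative): with stride == 1 Involution preserves the spatial size, and its output
+# always keeps c1 channels (a per-pixel kernel is generated for each of the c1 // 16 channel groups),
+# so configs should pass c2 == c1, as models/hub/yolov5s-cbam-involution.yaml does.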
diff --git a/models/hub/yolov5s-cbam-involution.yaml b/models/hub/yolov5s-cbam-involution.yaml
new file mode 100644
index 000000000000..9ac132e1cd78
--- /dev/null
+++ b/models/hub/yolov5s-cbam-involution.yaml
@@ -0,0 +1,60 @@
+# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
+# HIC-YOLOv5s: YOLOv5s with CBAM, an Involution layer and an extra P2/4 detection head for small objects
+
+# Parameters
+nc: 10  # number of classes
+depth_multiple: 0.33  # model depth multiple
+width_multiple: 0.50  # layer channel multiple
+anchors:
+  - [2.9434,4.0435, 3.8626,8.5592, 6.8534,5.9391]  # P2/4
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
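+  # four anchor sets: the first feeds the added P2/4 small-object head (layer 23), the rest the usual P3-P5 heads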
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [1024]],
+   [-1, 3, CBAM, [1024, 3]],
+   [-1, 1, SPPF, [1024, 5]],  # 10
+  ]
+
+# YOLOv5 v6.0 head
+head:
+  [[-1, 1, Involution, [1024, 1, 1]],
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 15
+
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [512, False]],  # 19
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 2], 1, Concat, [1]],
+   [-1, 3, C3, [256, False]],  # 23 P2/4 (160x160) head
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 19], 1, Concat, [1]],
+   [-1, 3, C3, [512, False]],  # 26 P3/8 (80x80) head
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 15], 1, Concat, [1]],
+   [-1, 3, C3, [256, False]],  # 29 P4/16 (40x40) head
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 11], 1, Concat, [1]],
+   [-1, 3, C3, [1024, False]],  # 32 P5/32 (20x20) head
+
+   [[23, 26, 29, 32], 1, Detect, [nc, anchors]],  # Detect(P2, P3, P4, P5)
+  ]
diff --git a/models/yolo.py b/models/yolo.py
index 4f4d567bec73..ad78d1fbd486 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -316,7 +316,7 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
         n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
         if m in {
                 Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
-                BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
+                BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, CBAM, Involution}:
             c1, c2 = ch[f], args[0]
             if c2 != no:  # if not output
                 c2 = make_divisible(c2 * gw, 8)
diff --git a/utils/general.py b/utils/general.py
index 135141e21436..f31ca46f8cff 100644
--- a/utils/general.py
+++ b/utils/general.py
@@ -264,7 +264,8 @@ def init_seeds(seed=0, deterministic=False):
     torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
     # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
     if deterministic and check_version(torch.__version__, '1.12.0'):  # https://github.com/ultralytics/yolov5/pull/8213
-        torch.use_deterministic_algorithms(True)
+        # nn.AdaptiveAvgPool2d (used by ChannelAttention) has no deterministic CUDA backward,
+        # so warn instead of raising when an op lacks a deterministic implementation
+        torch.use_deterministic_algorithms(True, warn_only=True)
         torch.backends.cudnn.deterministic = True
         os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
         os.environ['PYTHONHASHSEED'] = str(seed)