From d0338eab43eebc613fc23af98b1422e0e6cabc6f Mon Sep 17 00:00:00 2001 From: Muhammad Anas Raza Date: Fri, 23 Feb 2024 23:35:37 -0500 Subject: [PATCH] add test for cbam --- README.md | 4 ++++ k3_addons/layers/attention/cbam_test.py | 28 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/README.md b/README.md index 2198051..9e3676d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,10 @@ # k3-addons: Additional multi-backend functionality for Keras 3. ![Logo](.assets/k-addons.png) +# Installation + +`pip install k3-addons` + # Includes: - Layers - Pooling: diff --git a/k3_addons/layers/attention/cbam_test.py b/k3_addons/layers/attention/cbam_test.py index e69de29..9a5957a 100644 --- a/k3_addons/layers/attention/cbam_test.py +++ b/k3_addons/layers/attention/cbam_test.py @@ -0,0 +1,28 @@ +import pytest +import keras +from keras import ops + + +from k3_addons.layers.attention.cbam import ChannelAttention, SpatialAttention, CBAMBlock + + +@pytest.mark.parametrize("input_shape", [(1, 10, 10, 256), (1, 14, 14, 128)]) +def test_channel_attention(input_shape): + inputs = keras.random.normal(input_shape) + layer = ChannelAttention() + out = layer(inputs) + assert ops.shape(out) == (1, 1, 1,) + (input_shape[-1],) + +@pytest.mark.parametrize("input_shape", [(1, 10, 10, 256), (1, 14, 14, 128)]) +def test_spatial_attention(input_shape): + inputs = keras.random.normal(input_shape) + layer = SpatialAttention() + out = layer(inputs) + assert ops.shape(out) == input_shape[:-1] + (1,) # Dynamic assertion + +@pytest.mark.parametrize("input_shape", [(1, 10, 10, 256), (1, 14, 14, 128)]) +def test_cbam(input_shape): + inputs = keras.random.normal(input_shape) # Modify input shape + layer = CBAMBlock() + out = layer(inputs) + assert ops.shape(out) == input_shape # Output shape should remain the same \ No newline at end of file