Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/awsaf49/gcvit-tf
Browse files Browse the repository at this point in the history
  • Loading branch information
awsaf49 committed Dec 24, 2023
2 parents 1b292fb + 7ac8223 commit 174ea58
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 20 deletions.
9 changes: 9 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Rahman"
given-names: "Md Awsafur"
title: "Tensorflow 2.0 Implementation of GCViT: Global Context Vision Transformer"
version: 1.1.5
date-released: 2020-07-20
url: "https://awsaf49.github.io/gcvit-tf"
6 changes: 3 additions & 3 deletions gcvit/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def build(self, input_shape):
dim * self.qkv_size, use_bias=self.qkv_bias, name="qkv"
)
self.relative_position_bias_table = self.add_weight(
"relative_position_bias_table",
name="relative_position_bias_table",
shape=[
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
self.num_heads,
Expand All @@ -50,7 +50,7 @@ def build(self, input_shape):
self.proj_dropout, name="proj_drop"
)
self.softmax = tf.keras.layers.Activation("softmax", name="softmax")
self.relative_position_index = self.get_relative_position_index()
# self.relative_position_index = self.get_relative_position_index()
super().build(input_shape)

def get_relative_position_index(self):
Expand Down Expand Up @@ -101,7 +101,7 @@ def call(self, inputs, **kwargs):
attn = q @ tf.transpose(k, perm=[0, 1, 3, 2])
relative_position_bias = tf.gather(
self.relative_position_bias_table,
tf.reshape(self.relative_position_index, shape=[-1]),
tf.reshape(self.get_relative_position_index(), shape=[-1]),
)
relative_position_bias = tf.reshape(
relative_position_bias,
Expand Down
4 changes: 2 additions & 2 deletions gcvit/layers/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ def build(self, input_shape):
)
if self.layer_scale is not None:
self.gamma1 = self.add_weight(
"gamma1",
name="gamma1",
shape=[C],
initializer=tf.keras.initializers.Constant(self.layer_scale),
trainable=True,
dtype=self.dtype,
)
self.gamma2 = self.add_weight(
"gamma2",
name="gamma2",
shape=[C],
initializer=tf.keras.initializers.Constant(self.layer_scale),
trainable=True,
Expand Down
26 changes: 13 additions & 13 deletions gcvit/layers/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def build(self, input_shape):
self.avg_pool = AdaptiveAveragePooling2D(1, name="avg_pool")
self.fc = [
tf.keras.layers.Dense(
int(inp * self.expansion), use_bias=False, name="fc/0"
int(inp * self.expansion), use_bias=False, name="fc_0"
),
tf.keras.layers.Activation("gelu", name="fc/1"),
tf.keras.layers.Dense(self.oup, use_bias=False, name="fc/2"),
tf.keras.layers.Activation("sigmoid", name="fc/3"),
tf.keras.layers.Activation("gelu", name="fc_1"),
tf.keras.layers.Dense(self.oup, use_bias=False, name="fc_2"),
tf.keras.layers.Activation("sigmoid", name="fc_3"),
]
super().build(input_shape)

Expand Down Expand Up @@ -111,17 +111,17 @@ def build(self, input_shape):
strides=1,
padding="valid",
use_bias=False,
name="conv/0",
name="conv_0",
),
tf.keras.layers.Activation("gelu", name="conv/1"),
SE(name="conv/2"),
tf.keras.layers.Activation("gelu", name="conv_1"),
SE(name="conv_2"),
tf.keras.layers.Conv2D(
dim,
kernel_size=1,
strides=1,
padding="valid",
use_bias=False,
name="conv/3",
name="conv_3",
),
]
self.reduction = tf.keras.layers.Conv2D(
Expand Down Expand Up @@ -179,17 +179,17 @@ def build(self, input_shape):
strides=1,
padding="valid",
use_bias=False,
name="conv/0",
name="conv_0",
),
tf.keras.layers.Activation("gelu", name="conv/1"),
SE(name="conv/2"),
tf.keras.layers.Activation("gelu", name="conv_1"),
SE(name="conv_2"),
tf.keras.layers.Conv2D(
dim,
kernel_size=1,
strides=1,
padding="valid",
use_bias=False,
name="conv/3",
name="conv_3",
),
]
if not self.keep_dim:
Expand Down Expand Up @@ -237,7 +237,7 @@ def __init__(self, keep_dims=False, **kwargs):

def build(self, input_shape):
self.to_q_global = [
FeatExtract(keep_dim, name=f"to_q_global/{i}")
FeatExtract(keep_dim, name=f"to_q_global_{i}")
for i, keep_dim in enumerate(self.keep_dims)
]
super().build(input_shape)
Expand Down
2 changes: 1 addition & 1 deletion gcvit/layers/level.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def build(self, input_shape):
attn_drop=self.attn_drop,
path_drop=path_drop[i],
layer_scale=self.layer_scale,
name=f"blocks/{i}",
name=f"blocks_{i}",
)
for i in range(self.depth)
]
Expand Down
2 changes: 1 addition & 1 deletion gcvit/models/gcvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __init__(
path_drop=path_drop,
layer_scale=layer_scale,
resize_query=resize_query,
name=f"levels/{i}",
name=f"levels_{i}",
)
self.levels.append(level)
self.norm = tf.keras.layers.LayerNormalization(
Expand Down
Binary file added image/level.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 174ea58

Please sign in to comment.