lib.model.layers to Keras 3
torzdf committed Mar 14, 2024
1 parent 38feb60 commit 9a9232b
Showing 1 changed file with 33 additions and 41 deletions.
74 changes: 33 additions & 41 deletions lib/model/layers.py
@@ -30,7 +30,7 @@ def compute_output_shape(self, input_shape):
input_shape: tuple
The input shape to the layer
"""
- if self.data_format == 'channels_last':
+ if self.data_format == "channels_last":
return (input_shape[0], input_shape[3])
return (input_shape[0], input_shape[1])

@@ -46,7 +46,7 @@ def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:

def get_config(self) -> dict[str, T.Any]:
""" Set the Keras config """
- config = {'data_format': self.data_format}
+ config = {"data_format": self.data_format}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))

@@ -67,10 +67,10 @@ def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
tensor
A tensor or list/tuple of tensors
"""
- if self.data_format == 'channels_last':
- pooled = K.min(inputs, axis=[1, 2])
+ if self.data_format == "channels_last":
+ pooled = ops.min(inputs, axis=[1, 2])
else:
- pooled = K.min(inputs, axis=[2, 3])
+ pooled = ops.min(inputs, axis=[2, 3])
return pooled
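Note on the two pooling hunks (this one and the std-dev one below): the Keras 2 backend reductions `K.min` / `K.std` move to the backend-agnostic `keras.ops` namespace in Keras 3 with the same axis semantics. A minimal standalone sketch, assuming Keras 3 is installed; the input array and its shape are made up for illustration:

```python
import numpy as np
from keras import ops

# Dummy NHWC batch: 2 images, 4x4 spatial, 3 channels (illustrative values only)
images = np.random.rand(2, 4, 4, 3).astype("float32")

# Keras 3 replacement for K.min(inputs, axis=[1, 2]): reduce over the spatial axes
pooled_min = ops.min(images, axis=[1, 2])
# The std-dev pooling layer below does the same with ops.std
pooled_std = ops.std(images, axis=[1, 2])
print(pooled_min.shape, pooled_std.shape)  # both (2, 3): batch and channel dims remain
```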


@@ -90,10 +90,10 @@ def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
tensor
A tensor or list/tuple of tensors
"""
- if self.data_format == 'channels_last':
- pooled = K.std(inputs, axis=[1, 2])
+ if self.data_format == "channels_last":
+ pooled = ops.std(inputs, axis=[1, 2])
else:
- pooled = K.std(inputs, axis=[2, 3])
+ pooled = ops.std(inputs, axis=[2, 3])
return pooled


@@ -130,16 +130,11 @@ def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
:class:`torch.Tensor`
A tensor or list/tuple of tensors
"""
- if isinstance(self.size, int):
- retval = K.resize_images(inputs,
- self.size,
- self.size,
- "channels_last",
- interpolation=self.interpolation)
- else:
- # Arbitrary resizing
- size = int(round(K.int_shape(inputs)[1] * self.size))
- retval = tf.image.resize(inputs, (size, size), method=self.interpolation)
+ size = int(round(inputs.shape[1] * self.size)), int(round(inputs.shape[2] * self.size))
+ retval = ops.image.resize(inputs,
+ size,
+ interpolation=self.interpolation,
+ data_format="channels_last")
return retval
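Note on this hunk: the two Keras 2 code paths (`K.resize_images` for integer factors, `tf.image.resize` for arbitrary ones) collapse into a single `keras.ops.image.resize` call that takes an explicit `(height, width)` target, which is also why the matching `compute_output_shape` hunk below now rounds the scaled dimensions. A rough standalone sketch of the same idea, with a made-up scale factor and input shape:

```python
import numpy as np
from keras import ops

scale = 1.5  # hypothetical non-integer upscaling factor
images = np.random.rand(1, 8, 8, 3).astype("float32")  # dummy NHWC input

# Absolute target size, computed the same way as the new call()
size = (int(round(images.shape[1] * scale)), int(round(images.shape[2] * scale)))

resized = ops.image.resize(images, size,
                           interpolation="bilinear",
                           data_format="channels_last")
print(resized.shape)  # (1, 12, 12, 3)
```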

def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]:
@@ -159,7 +154,7 @@ def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]:
An input shape tuple
"""
batch, height, width, channels = input_shape
- return (batch, height * self.size, width * self.size, channels)
+ return (batch, int(round(height * self.size)), int(round(width * self.size)), channels)

def get_config(self) -> dict[str, T.Any]:
"""Returns the config of the layer.
@@ -285,11 +280,11 @@ def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
input_shape = inputs.shape
if len(input_shape) != 4:
- raise ValueError('Inputs should have rank ' +
+ raise ValueError("Inputs should have rank " +
str(4) +
- '; Received input shape:', str(input_shape))
+ "; Received input shape:", str(input_shape))

- if self.data_format == 'channels_first':
+ if self.data_format == "channels_first":
batch_size, channels, height, width = input_shape
if batch_size is None:
batch_size = -1
@@ -300,7 +295,7 @@ def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
out = ops.reshape(inputs, (batch_size, r_height, r_width, o_channels, height, width))
out = ops.transpose(out, (0, 3, 4, 1, 5, 2))
out = ops.reshape(out, (batch_size, o_channels, o_height, o_width))
- elif self.data_format == 'channels_last':
+ elif self.data_format == "channels_last":
batch_size, height, width, channels = input_shape
if batch_size is None:
batch_size = -1
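Aside on the surrounding context lines (only the quoting changes in this commit): the reshape → transpose → reshape sequence is the usual sub-pixel / depth-to-space shuffle. A small shape walk-through with assumed values, mirroring the channels-first branch shown above:

```python
import numpy as np
from keras import ops

# Assumed example: size = (2, 2) and a channels-first input of shape (1, 12, 4, 4)
batch_size, channels, height, width = 1, 12, 4, 4
r_height, r_width = 2, 2
o_channels = channels // (r_height * r_width)            # 3
o_height, o_width = height * r_height, width * r_width   # 8, 8

inputs = np.random.rand(batch_size, channels, height, width).astype("float32")

out = ops.reshape(inputs, (batch_size, r_height, r_width, o_channels, height, width))
out = ops.transpose(out, (0, 3, 4, 1, 5, 2))
out = ops.reshape(out, (batch_size, o_channels, o_height, o_width))
print(out.shape)  # (1, 3, 8, 8)
```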
@@ -330,11 +325,11 @@ def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]:
An input shape tuple
"""
if len(input_shape) != 4:
- raise ValueError('Inputs should have rank ' +
+ raise ValueError("Inputs should have rank " +
str(4) +
- '; Received input shape:', str(input_shape))
+ "; Received input shape:", str(input_shape))

- if self.data_format == 'channels_first':
+ if self.data_format == "channels_first":
height = None
width = None
if input_shape[2] is not None:
@@ -344,13 +339,13 @@ def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]:
channels = input_shape[1] // self.size[0] // self.size[1]

if channels * self.size[0] * self.size[1] != input_shape[1]:
- raise ValueError('channels of input and size are incompatible')
+ raise ValueError("channels of input and size are incompatible")

retval = (input_shape[0],
channels,
height,
width)
- elif self.data_format == 'channels_last':
+ elif self.data_format == "channels_last":
height = None
width = None
if input_shape[1] is not None:
@@ -360,7 +355,7 @@ def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]:
channels = input_shape[3] // self.size[0] // self.size[1]

if channels * self.size[0] * self.size[1] != input_shape[3]:
- raise ValueError('channels of input and size are incompatible')
+ raise ValueError("channels of input and size are incompatible")

retval = (input_shape[0],
height,
@@ -383,8 +378,8 @@ class name. These are handled by `Network` (one layer of abstraction above).
dict
A python dictionary containing the layer configuration
"""
- config = {'size': self.size,
- 'data_format': self.data_format}
+ config = {"size": self.size,
+ "data_format": self.data_format}
base_config = super().get_config()

return dict(list(base_config.items()) + list(config.items()))
@@ -417,7 +412,7 @@ def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
:class:`torch.Tensor`
The output Tensor
"""
- return inputs * K.sigmoid(1.702 * inputs)
+ return inputs * ops.sigmoid(1.702 * inputs)
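Note on this hunk: apart from the `K` → `ops` rename, the expression is the sigmoid approximation of GELU, GELU(x) ≈ x · sigmoid(1.702 · x). A quick plain-Python check against the exact definition (values chosen arbitrarily):

```python
from math import erf, exp, sqrt

def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF
    return x * 0.5 * (1.0 + erf(x / sqrt(2.0)))

def gelu_sigmoid(x: float) -> float:
    # The layer's approximation: x * sigmoid(1.702 * x)
    return x / (1.0 + exp(-1.702 * x))

for value in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"{value:+.1f}  exact={gelu_exact(value):+.4f}  approx={gelu_sigmoid(value):+.4f}")
```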


class ReflectionPadding2D(keras.layers.Layer):
@@ -524,12 +519,9 @@ def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
padding_left = padding_width // 2
padding_right = padding_width - padding_left

- return tf.pad(inputs,
- [[0, 0],
- [padding_top, padding_bot],
- [padding_left, padding_right],
- [0, 0]],
- 'REFLECT')
+ return ops.pad(inputs,
+ [[0, 0], [padding_top, padding_bot], [padding_left, padding_right], [0, 0]],
+ mode="reflect")

def get_config(self) -> dict[str, T.Any]:
"""Returns the config of the layer.
@@ -546,8 +538,8 @@ class name. These are handled by `Network` (one layer of abstraction above).
dict
A python dictionary containing the layer configuration
"""
- config = {'stride': self.stride,
- 'kernel_size': self.kernel_size}
+ config = {"stride": self.stride,
+ "kernel_size": self.kernel_size}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))

@@ -583,7 +575,7 @@ def call(self, inputs, *args, **kwargs):
:class:`torch.Tensor`
A tensor or list/tuple of tensors
"""
- return tf.nn.swish(inputs * self.beta)
+ return ops.nn.swish(inputs * self.beta)
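Note on this hunk: `tf.nn.swish` is swapped for the `keras.ops` equivalent; swish(z) = z · sigmoid(z), so with the layer's `beta` the call computes (beta · x) · sigmoid(beta · x). A quick sketch of that identity using `ops.sigmoid`; the `beta` value here is assumed for illustration, not taken from the layer's defaults:

```python
import numpy as np
from keras import ops

beta = 1.0  # assumed value; the layer's default is not shown in this hunk
x = ops.convert_to_tensor(np.linspace(-3.0, 3.0, 7).astype("float32"))

# swish(z) = z * sigmoid(z), applied to z = beta * x as in call()
manual = (beta * x) * ops.sigmoid(beta * x)
print(ops.convert_to_numpy(manual))
```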

def get_config(self):
"""Returns the config of the layer.
