Skip to content

Commit

Permalink
black applied
Browse files Browse the repository at this point in the history
  • Loading branch information
gallenaxel committed Mar 7, 2024
1 parent 857aee3 commit 03b39e4
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
14 changes: 8 additions & 6 deletions baler/modules/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ def encoder_saver(model, model_path: str) -> None:
Returns:
None: Saved encoder state dictionary as `.pt` file.
"""
if hasattr(model.encoder,'state_dict'):
torch.save(model.encoder.state_dict(),model_path)
else: model.save_encoder(model_path)
if hasattr(model.encoder, "state_dict"):
torch.save(model.encoder.state_dict(), model_path)
else:
model.save_encoder(model_path)


def decoder_saver(model, model_path: str) -> None:
Expand All @@ -72,9 +73,10 @@ def decoder_saver(model, model_path: str) -> None:
Returns:
None: Saved decoder state dictionary as `.pt` file.
"""
if hasattr(model.decoder,'state_dict'):
torch.save(model.decoder.state_dict(),model_path)
else: model.save_decoder(model_path)
if hasattr(model.decoder, "state_dict"):
torch.save(model.decoder.state_dict(), model_path)
else:
model.save_decoder(model_path)


def initialise_model(model_name: str):
Expand Down
20 changes: 15 additions & 5 deletions baler/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,11 +713,11 @@ def get_final_layer_dims(self):
def set_final_layer_dims(self, conv_op_shape):
self.conv_op_shape = conv_op_shape


class PJ_Conv_AE_FPGA(nn.Module):
def __init__(self, n_features, z_dim=10, *args, **kwargs):
super(PJ_Conv_AE_FPGA, self).__init__(*args, **kwargs)


# Encoder layers
self.en1 = nn.Conv2d(1, 20, kernel_size=5, stride=2, padding=2)
self.en_act1 = nn.ReLU()
Expand Down Expand Up @@ -766,17 +766,27 @@ def forward(self, x):
return decoded

def get_final_layer_dims(self):
return
return

def set_final_layer_dims(self, conv_op_shape):
self.conv_op_shape = conv_op_shape

def save_encoder(self, file_path):
# Create an instance of the encoder
encoder_instance = nn.Sequential(self.en1, self.en_act1, self.en2, self.en_act2, self.en3, self.en4)
encoder_instance = nn.Sequential(
self.en1, self.en_act1, self.en2, self.en_act2, self.en3, self.en4
)
torch.save(encoder_instance.state_dict(), file_path)

def save_decoder(self, file_path):
# Create an instance of the decoder
decoder_instance = nn.Sequential(self.de1, self.de_act1, self.de2, self.de_unflatten, self.de_conv1, self.de_conv2, self.de_act2)
torch.save(decoder_instance.state_dict(), file_path)
decoder_instance = nn.Sequential(
self.de1,
self.de_act1,
self.de2,
self.de_unflatten,
self.de_conv1,
self.de_conv2,
self.de_act2,
)
torch.save(decoder_instance.state_dict(), file_path)

0 comments on commit 03b39e4

Please sign in to comment.