Skip to content

Commit

Permalink
Implement custom training loop
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Jun 7, 2024
1 parent 1d3e8b2 commit 3e5966d
Show file tree
Hide file tree
Showing 4 changed files with 685 additions and 621 deletions.
3 changes: 2 additions & 1 deletion plugins/train/model/_base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def _output_summary(self) -> None:
else:
# print to logger
print_fn = self._summary_to_log
parent = self.model
for idx, model in enumerate(get_all_sub_models(self.model)):
if idx == 0:
parent = model
Expand All @@ -349,7 +350,7 @@ def _compile_model(self) -> None:

self._loss.configure(self.model)
losses = list(self._loss.functions.values())
self.model.compile(optimizer=optimizer, loss=losses, metrics=losses)
self.model.compile(optimizer=optimizer, loss=losses)
self._state.add_session_loss_names(self._loss.names)
logger.debug("Compiled Model: %s", self.model)

Expand Down
Loading

0 comments on commit 3e5966d

Please sign in to comment.