Skip to content

Commit

Permalink
reset loss_values on refit and test formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h committed Oct 25, 2023
1 parent 4b07410 commit 02a3016
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
1 change: 1 addition & 0 deletions ctgan/synthesizers/tvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def fit(self, train_data, discrete_columns=()):
list(encoder.parameters()) + list(self.decoder.parameters()),
weight_decay=self.l2scale)

self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
iterator = tqdm(range(self.epochs), disable=(not self.verbose))
if self.verbose:
iterator_description = 'Loss: {loss:.3f}'
Expand Down
24 changes: 17 additions & 7 deletions tests/unit/synthesizer/test_tvae.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,33 @@
"""TVAE unit testing module."""

from unittest.mock import MagicMock, call, patch
from unittest.mock import MagicMock, Mock, call, patch

import pandas as pd

from ctgan.synthesizers import TVAE


class TestTVAE:

@patch('ctgan.synthesizers.tvae._loss_function')
@patch('ctgan.synthesizers.tvae.tqdm')
def test_fit_verbose(self, tqdm_mock):
def test_fit_verbose(self, tqdm_mock, loss_func_mock):
"""Test verbose parameter prints progress bar."""
# Setup
epochs = 10
epochs = 1

def mock_iter():
for i in range(epochs):
yield i

def mock_add(a, b):
mock_loss = Mock()
mock_loss.detach().cpu().item.return_value = 1.23456789
return mock_loss

loss_mock = MagicMock()
loss_mock.__add__ = mock_add
loss_func_mock.return_value = (loss_mock, loss_mock)

iterator_mock = MagicMock()
iterator_mock.__iter__.side_effect = mock_iter
tqdm_mock.return_value = iterator_mock
Expand All @@ -32,6 +41,7 @@ def mock_iter():
synth.fit(train_data)

# Assert
tqdm_mock.assert_called_once_with(range(10), disable=False)
iterator_mock.set_description.call_args_list[0] == call('Loss: 0.000')
assert iterator_mock.set_description.call_count == 11
tqdm_mock.assert_called_once_with(range(epochs), disable=False)
assert iterator_mock.set_description.call_args_list[0] == call('Loss: 0.000')
assert iterator_mock.set_description.call_args_list[1] == call('Loss: 1.235')
assert iterator_mock.set_description.call_count == 2

0 comments on commit 02a3016

Please sign in to comment.