diff --git a/.github/workflows/dependency_checker.yml b/.github/workflows/dependency_checker.yml
new file mode 100644
index 00000000..0afe300e
--- /dev/null
+++ b/.github/workflows/dependency_checker.yml
@@ -0,0 +1,29 @@
+name: Dependency Checker
+on:
+  schedule:
+    - cron: '0 0 * * 1-5'
+  workflow_dispatch:
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python 3.9
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.9
+      - name: Install dependencies
+        run: |
+          python -m pip install .[dev]
+          make check-deps OUTPUT_FILEPATH=latest_requirements.txt
+      - name: Create pull request
+        id: cpr
+        uses: peter-evans/create-pull-request@v4
+        with:
+          token: ${{ secrets.GH_ACCESS_TOKEN }}
+          commit-message: Update latest dependencies
+          title: Automated Latest Dependency Updates
+          body: "This is an auto-generated PR with **latest** dependency updates."
+          branch: latest-dependency-update
+          branch-suffix: short-commit-hash
+          base: main
diff --git a/.github/workflows/readme.yml b/.github/workflows/readme.yml
index b3a44130..03a14ac5 100644
--- a/.github/workflows/readme.yml
+++ b/.github/workflows/readme.yml
@@ -22,5 +22,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         python -m pip install invoke rundoc .
+        python -m pip install tomli
+        python -m pip install packaging
     - name: Run the README.md
       run: invoke readme
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 893eb130..ca9fdba2 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -219,9 +219,9 @@ This will perform the following actions:
 2. Bump the current version to the next release candidate, ``X.Y.Z.dev(N+1)``
 
 After this is done, the new pre-release can be installed by including the ``dev`` section in the
-dependency specification, either in ``setup.py``::
+dependency specification, either in ``pyproject.toml``::
 
-    install_requires = [
+    dependencies = [
         ...
         'ctgan>=X.Y.Z.dev',
         ...
diff --git a/HISTORY.md b/HISTORY.md
index bbc48009..d1f84a61 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,5 +1,36 @@
 # History
 
+## v0.9.1 - 2024-03-14
+
+This release changes the `loss_values` attribute of a CTGAN model to contain floats instead of `torch.Tensors`.
+
+### New Features
+
+* Return loss values as float values not PyTorch objects - Issue [#332](https://github.com/sdv-dev/CTGAN/issues/332) by @fealho
+
+### Maintenance
+
+* Transition from using setup.py to pyproject.toml to specify project metadata - Issue [#333](https://github.com/sdv-dev/CTGAN/issues/333) by @R-Palazzo
+* Remove bumpversion and use bump-my-version - Issue [#334](https://github.com/sdv-dev/CTGAN/issues/334) by @R-Palazzo
+* Add dependency checker - Issue [#336](https://github.com/sdv-dev/CTGAN/issues/336) by @amontanez24
+
+## v0.9.0 - 2024-02-13
+
+This release makes CTGAN sampling more efficient by saving the frequency of each categorical value.
+
+### New Features
+
+* Improve DataSampler efficiency - Issue [#327](https://github.com/sdv-dev/CTGAN/issues/327) by @fealho
+
+## v0.8.0 - 2023-11-13
+
+This release adds a progress bar that will show when setting the `verbose` parameter to `True`
+when initializing `TVAE`.
+
+### New Features
+
+* Add verbosity TVAE (progress bar + save the loss values) - Issue [#300](https://github.com/sdv-dev/CTGAN/issues/300) by @frances-h
+
 ## v0.7.5 - 2023-10-05
 
 This release adds a progress bar that will show when setting the `verbose` parameter to True when initializing `CTGAN`. It also removes a warning that was showing.
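Reviewer note on the v0.9.1 change above: after fitting, `model.loss_values` is a pandas DataFrame whose loss entries are now plain Python floats rather than `torch.Tensor` objects. A minimal sketch of how that surfaces to users follows; the toy data and the `Generator Loss` column name are assumptions for illustration, not part of this diff.

```python
import numpy as np
import pandas as pd

from ctgan import CTGAN

# Invented toy dataset, only to exercise fit() briefly.
data = pd.DataFrame({
    'age': np.random.randint(18, 65, size=500),
    'category': np.random.choice(['a', 'b', 'c'], size=500),
})

model = CTGAN(epochs=2)
model.fit(data, discrete_columns=['category'])

# Loss entries are plain floats, so the frame can be plotted or written
# to CSV without detaching tensors first.
print(model.loss_values.head())
assert model.loss_values['Generator Loss'].map(lambda v: isinstance(v, float)).all()
```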
diff --git a/MANIFEST.in b/MANIFEST.in
deleted file mode 100644
index 469520f5..00000000
--- a/MANIFEST.in
+++ /dev/null
@@ -1,11 +0,0 @@
-include AUTHORS.rst
-include CONTRIBUTING.rst
-include HISTORY.md
-include LICENSE
-include README.md
-
-recursive-include tests *
-recursive-exclude * __pycache__
-recursive-exclude * *.py[co]
-
-recursive-include docs *.md *.rst conf.py Makefile make.bat *.jpg *.png *.gif
diff --git a/Makefile b/Makefile
index ea6a8132..106671de 100644
--- a/Makefile
+++ b/Makefile
@@ -76,16 +76,9 @@ install-test: clean-build clean-pyc ## install the package and test dependencies
 install-develop: clean-build clean-pyc ## install the package in editable mode and dependencies for development
 	pip install -e .[dev]
 
-MINIMUM := $(shell sed -n '/install_requires = \[/,/]/p' setup.py | head -n-1 | tail -n+2 | sed 's/ *\(.*\),$?$$/\1/g' | tr '>' '=')
-
-.PHONY: install-minimum
-install-minimum: ## install the minimum supported versions of the package dependencies
-	pip install $(MINIMUM)
-
 # LINT TARGETS
-
 .PHONY: lint
 lint: ## check style with flake8 and isort
 	invoke lint
@@ -138,8 +131,7 @@ coverage: ## check code coverage quickly with the default Python
 
 .PHONY: dist
 dist: clean ## builds source and wheel package
-	python setup.py sdist
-	python setup.py bdist_wheel
+	python -m build --wheel --sdist
 	ls -l dist
 
 .PHONY: publish-confirm
@@ -161,34 +153,34 @@ publish: dist publish-confirm ## package and upload a release
 bumpversion-release: ## Merge main to stable and bumpversion release
 	git checkout stable || git checkout -b stable
 	git merge --no-ff main -m"make release-tag: Merge branch 'main' into stable"
-	bumpversion release
+	bump-my-version bump release
 	git push --tags origin stable
 
 .PHONY: bumpversion-release-test
 bumpversion-release-test: ## Merge main to stable and bumpversion release
 	git checkout stable || git checkout -b stable
 	git merge --no-ff main -m"make release-tag: Merge branch 'main' into stable"
-	bumpversion release --no-tag
+	bump-my-version bump release --no-tag
 	@echo git push --tags origin stable
 
 .PHONY: bumpversion-patch
 bumpversion-patch: ## Merge stable to main and bumpversion patch
 	git checkout main
 	git merge stable
-	bumpversion --no-tag patch
+	bump-my-version bump --no-tag patch
 	git push
 
 .PHONY: bumpversion-candidate
 bumpversion-candidate: ## Bump the version to the next candidate
-	bumpversion candidate --no-tag
+	bump-my-version bump candidate --no-tag
 
 .PHONY: bumpversion-minor
 bumpversion-minor: ## Bump the version the next minor skipping the release
-	bumpversion --no-tag minor
+	bump-my-version bump --no-tag minor
 
 .PHONY: bumpversion-major
 bumpversion-major: ## Bump the version the next major skipping the release
-	bumpversion --no-tag major
+	bump-my-version bump --no-tag major
 
 .PHONY: bumpversion-revert
 bumpversion-revert: ## Undo a previous bumpversion-release
@@ -238,3 +230,10 @@ release-minor: check-release bumpversion-minor release
 
 .PHONY: release-major
 release-major: check-release bumpversion-major release
+
+# Dependency targets
+
+.PHONY: check-deps
+check-deps:
+	$(eval allow_list='numpy=|pandas=|scikit-learn=|tqdm=|torch=|rdt=')
+	pip freeze | grep -v "CTGAN.git" | grep -E $(allow_list) > $(OUTPUT_FILEPATH)
diff --git a/ctgan/__init__.py b/ctgan/__init__.py
index f9b6c353..ce3d62fa 100644
--- a/ctgan/__init__.py
+++ b/ctgan/__init__.py
@@ -4,7 +4,7 @@
 
 __author__ = 'DataCebo, Inc.'
 __email__ = 'info@sdv.dev'
-__version__ = '0.7.5'
+__version__ = '0.9.1.dev1'
 
 from ctgan.demo import load_demo
 from ctgan.synthesizers.ctgan import CTGAN
diff --git a/ctgan/data_sampler.py b/ctgan/data_sampler.py
index 5cbf339d..53a5a32b 100644
--- a/ctgan/data_sampler.py
+++ b/ctgan/data_sampler.py
@@ -7,7 +7,7 @@ class DataSampler(object):
     """DataSampler samples the conditional vector and corresponding data for CTGAN."""
 
     def __init__(self, data, output_info, log_frequency):
-        self._data = data
+        self._data_length = len(data)
 
         def is_discrete_column(column_info):
             return (len(column_info) == 1
@@ -115,33 +115,34 @@ def sample_original_condvec(self, batch):
         if self._n_discrete_columns == 0:
             return None
 
+        category_freq = self._discrete_column_category_prob.flatten()
+        category_freq = category_freq[category_freq != 0]
+        category_freq = category_freq / np.sum(category_freq)
+        col_idxs = np.random.choice(np.arange(len(category_freq)), batch, p=category_freq)
         cond = np.zeros((batch, self._n_categories), dtype='float32')
-
-        for i in range(batch):
-            row_idx = np.random.randint(0, len(self._data))
-            col_idx = np.random.randint(0, self._n_discrete_columns)
-            matrix_st = self._discrete_column_matrix_st[col_idx]
-            matrix_ed = matrix_st + self._discrete_column_n_category[col_idx]
-            pick = np.argmax(self._data[row_idx, matrix_st:matrix_ed])
-            cond[i, pick + self._discrete_column_cond_st[col_idx]] = 1
+        cond[np.arange(batch), col_idxs] = 1
 
         return cond
 
-    def sample_data(self, n, col, opt):
+    def sample_data(self, data, n, col, opt):
         """Sample data from original training data satisfying the sampled conditional vector.
 
+        Args:
+            data:
+                The training data.
+
         Returns:
-            n rows of matrix data.
+            n:
+                n rows of matrix data.
         """
         if col is None:
-            idx = np.random.randint(len(self._data), size=n)
-            return self._data[idx]
+            idx = np.random.randint(len(data), size=n)
+            return data[idx]
 
         idx = []
         for c, o in zip(col, opt):
             idx.append(np.random.choice(self._rid_by_cat_cols[c][o]))
 
-        return self._data[idx]
+        return data[idx]
 
     def dim_cond_vec(self):
         """Return the total number of categories."""
diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py
index 82dd3aba..c858ea51 100644
--- a/ctgan/synthesizers/ctgan.py
+++ b/ctgan/synthesizers/ctgan.py
@@ -175,8 +175,7 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di
         self._transformer = None
         self._data_sampler = None
         self._generator = None
-
-        self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Distriminator Loss'])
+        self.loss_values = None
 
     @staticmethod
     def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
@@ -355,7 +354,8 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                    condvec = self._data_sampler.sample_condvec(self._batch_size)
                    if condvec is None:
                        c1, m1, col, opt = None, None, None, None
-                        real = self._data_sampler.sample_data(self._batch_size, col, opt)
+                        real = self._data_sampler.sample_data(
+                            train_data, self._batch_size, col, opt)
                    else:
                        c1, m1, col, opt = condvec
                        c1 = torch.from_numpy(c1).to(self._device)
@@ -365,7 +365,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                        perm = np.arange(self._batch_size)
                        np.random.shuffle(perm)
                        real = self._data_sampler.sample_data(
-                            self._batch_size, col[perm], opt[perm])
+                            train_data, self._batch_size, col[perm], opt[perm])
                        c2 = c1[perm]
 
                    fake = self._generator(fakez)
@@ -422,8 +422,8 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                loss_g.backward()
                optimizerG.step()
 
-                generator_loss = loss_g.detach().cpu()
-                discriminator_loss = loss_d.detach().cpu()
+                generator_loss = loss_g.detach().cpu().item()
+                discriminator_loss = loss_d.detach().cpu().item()
 
            epoch_loss_df = pd.DataFrame({
                'Epoch': [i],
diff --git a/ctgan/synthesizers/tvae.py b/ctgan/synthesizers/tvae.py
index 8dafd4e7..ba4d4855 100644
--- a/ctgan/synthesizers/tvae.py
+++ b/ctgan/synthesizers/tvae.py
@@ -1,11 +1,13 @@
 """TVAE module."""
 
 import numpy as np
+import pandas as pd
 import torch
 from torch.nn import Linear, Module, Parameter, ReLU, Sequential
 from torch.nn.functional import cross_entropy
 from torch.optim import Adam
 from torch.utils.data import DataLoader, TensorDataset
+from tqdm import tqdm
 
 from ctgan.data_transformer import DataTransformer
 from ctgan.synthesizers.base import BaseSynthesizer, random_state
@@ -112,7 +114,8 @@ def __init__(
        batch_size=500,
        epochs=300,
        loss_factor=2,
-        cuda=True
+        cuda=True,
+        verbose=False
    ):
 
        self.embedding_dim = embedding_dim
@@ -123,6 +126,8 @@
        self.batch_size = batch_size
        self.loss_factor = loss_factor
        self.epochs = epochs
+        self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
+        self.verbose = verbose
 
        if not cuda or not torch.cuda.is_available():
            device = 'cpu'
@@ -159,7 +164,15 @@ def fit(self, train_data, discrete_columns=()):
            list(encoder.parameters()) + list(self.decoder.parameters()),
            weight_decay=self.l2scale)
 
-        for i in range(self.epochs):
+        self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
+        iterator = tqdm(range(self.epochs), disable=(not self.verbose))
+        if self.verbose:
+            iterator_description = 'Loss: {loss:.3f}'
+            iterator.set_description(iterator_description.format(loss=0))
+
+        for i in iterator:
+            loss_values = []
+            batch = []
            for id_, data in enumerate(loader):
                optimizerAE.zero_grad()
                real = data[0].to(self._device)
@@ -176,6 +189,26 @@ def fit(self, train_data, discrete_columns=()):
                optimizerAE.step()
                self.decoder.sigma.data.clamp_(0.01, 1.0)
 
+                batch.append(id_)
+                loss_values.append(loss.detach().cpu().item())
+
+            epoch_loss_df = pd.DataFrame({
+                'Epoch': [i] * len(batch),
+                'Batch': batch,
+                'Loss': loss_values
+            })
+            if not self.loss_values.empty:
+                self.loss_values = pd.concat(
+                    [self.loss_values, epoch_loss_df]
+                ).reset_index(drop=True)
+            else:
+                self.loss_values = epoch_loss_df
+
+            if self.verbose:
+                iterator.set_description(
+                    iterator_description.format(
+                        loss=loss.detach().cpu().item()))
+
    @random_state
    def sample(self, samples):
        """Sample data similar to the training data.
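To make the new `TVAE` verbosity concrete, here is a small usage sketch (not part of the diff; the dataset and column names are invented): `verbose=True` drives a tqdm bar whose description reads `Loss: x.xxx`, and `loss_values` collects one row per epoch/batch pair.

```python
import numpy as np
import pandas as pd

from ctgan.synthesizers.tvae import TVAE

# Invented toy data for illustration only.
data = pd.DataFrame({
    'amount': np.random.gamma(2.0, 10.0, size=500),
    'status': np.random.choice(['new', 'open', 'closed'], size=500),
})

model = TVAE(epochs=5, verbose=True)  # prints a tqdm bar described as 'Loss: x.xxx'
model.fit(data, discrete_columns=['status'])

# One row per (epoch, batch); the Loss column holds plain floats.
print(model.loss_values.columns.tolist())  # ['Epoch', 'Batch', 'Loss']
print(model.loss_values.groupby('Epoch')['Loss'].mean())
```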
diff --git a/latest_requirements.txt b/latest_requirements.txt
new file mode 100644
index 00000000..33b0963a
--- /dev/null
+++ b/latest_requirements.txt
@@ -0,0 +1,6 @@
+numpy==1.26.4
+pandas==2.2.1
+rdt==1.10.0
+scikit-learn==1.4.1.post1
+torch==2.2.1
+tqdm==4.66.2
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..27d0197d
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,187 @@
+[project]
+name = 'ctgan'
+description = 'Create tabular synthetic data using a conditional GAN'
+authors = [{ name = 'DataCebo, Inc.', email = 'info@sdv.dev' }]
+classifiers = [
+    'Development Status :: 2 - Pre-Alpha',
+    'Intended Audience :: Developers',
+    'License :: Free for non-commercial use',
+    'Natural Language :: English',
+    'Programming Language :: Python :: 3',
+    'Programming Language :: Python :: 3.8',
+    'Programming Language :: Python :: 3.9',
+    'Programming Language :: Python :: 3.10',
+    'Programming Language :: Python :: 3.11',
+    'Topic :: Scientific/Engineering :: Artificial Intelligence',
+]
+keywords = ['ctgan', 'CTGAN']
+dynamic = ['version']
+license = { text = 'BSL-1.1' }
+requires-python = '>=3.8,<3.12'
+readme = 'README.md'
+dependencies = [
+    "numpy>=1.20.0;python_version<'3.10'",
+    "numpy>=1.23.3;python_version>='3.10'",
+    "pandas>=1.1.3;python_version<'3.10'",
+    "pandas>=1.3.4;python_version>='3.10' and python_version<'3.11'",
+    "pandas>=1.5.0;python_version>='3.11'",
+    "scikit-learn>=0.24;python_version<'3.10'",
+    "scikit-learn>=1.1.3;python_version>='3.10'",
+    "torch>=1.8.0;python_version<'3.10'",
+    "torch>=1.11.0;python_version>='3.10' and python_version<'3.11'",
+    "torch>=2.0.0;python_version>='3.11'",
+    'tqdm>=4.15',
+    'rdt>=1.6.1',
+]
+
+[project.urls]
+"Source Code"= "https://github.com/sdv-dev/CTGAN/"
+"Issue Tracker" = "https://github.com/sdv-dev/CTGAN/issues"
+"Changes" = "https://github.com/sdv-dev/CTGAN/blob/main/HISTORY.md"
+"Twitter" = "https://twitter.com/sdv_dev"
+"Chat" = "https://bit.ly/sdv-slack-invite"
+
+[project.entry-points]
+ctgan = { main = 'ctgan.cli.__main__:main' }
+
+[project.optional-dependencies]
+test = [
+    'pytest>=3.4.2',
+    'pytest-rerunfailures>=9.1.1,<10',
+    'pytest-cov>=2.6.0',
+    'rundoc>=0.4.3,<0.5',
+    'pytest-runner >= 2.11.1',
+    'tomli>=2.0.0,<3',
+]
+dev = [
+    'ctgan[test]',
+
+    # general
+    'pip>=9.0.1',
+    'build>=1.0.0,<2',
+    'bump-my-version>=0.18.3,<1',
+    'watchdog>=0.8.3,<0.11',
+
+    # style check
+    'flake8>=3.7.7,<4',
+    'isort>=4.3.4,<5',
+    'dlint>=0.11.0,<0.12',  # code security addon for flake8
+    'flake8-debugger>=4.0.0,<4.1',
+    'flake8-mock>=0.3,<0.4',
+    'flake8-mutable>=1.2.0,<1.3',
+    'flake8-absolute-import>=1.0,<2',
+    'flake8-multiline-containers>=0.0.18,<0.1',
+    'flake8-print>=4.0.0,<4.1',
+    'flake8-quotes>=3.3.0,<4',
+    'flake8-fixme>=1.1.1,<1.2',
+    'flake8-expression-complexity>=0.0.9,<0.1',
+    'flake8-eradicate>=1.1.0,<1.2',
+    'flake8-builtins>=1.5.3,<1.6',
+    'flake8-variables-names>=0.0.4,<0.1',
+    'pandas-vet>=0.2.2,<0.3',
+    'flake8-comprehensions>=3.6.1,<3.7',
+    'dlint>=0.11.0,<0.12',
+    'flake8-docstrings>=1.5.0,<2',
+    'flake8-sfs>=0.0.3,<0.1',
+    'flake8-pytest-style>=1.5.0,<2',
+
+    # fix style issues
+    'autoflake>=1.1,<2',
+    'autopep8>=1.4.3,<1.6',
+
+    # distribute on PyPI
+    'twine>=1.10.0,<4',
+    'wheel>=0.30.0',
+
+    # Advanced testing
+    'coverage>=4.5.1,<6',
+    'tox>=2.9.1,<4',
+
+    'invoke',
+]
+
+[tool.setuptools]
+include-package-data = true
+license-files = ['LICENSE']
+
+[tool.setuptools.packages.find]
+include = ['ctgan', 'ctgan.*']
+namespaces = false
+
+[tool.setuptools.package-data]
+'*' = [
+    'AUTHORS.rst',
+    'CONTRIBUTING.rst',
+    'HISTORY.md',
+    'README.md',
+    '*.md',
+    '*.rst',
+    'conf.py',
+    'Makefile',
+    'make.bat',
+    '*.jpg',
+    '*.png',
+    '*.gif'
+]
+'tests' = ['*']
+
+[tool.setuptools.exclude-package-data]
+'*' = [
+    '* __pycache__',
+    '*.py[co]',
+]
+
+[tool.setuptools.dynamic]
+version = {attr = 'ctgan.__version__'}
+
+[tool.isort]
+include_trailing_comment = true
+line_length = 99
+lines_between_types = 0
+multi_line_output = 4
+not_skip = ['__init__.py']
+use_parentheses = true
+
+[tool.pydocstyle]
+convention = 'google'
+add-ignore = ['D107', 'D407', 'D417']
+
+[tool.pytest.ini_options]
+collect_ignore = ['pyproject.toml']
+
+[tool.bumpversion]
+current_version = "0.9.1.dev1"
+parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
+serialize = [
+    '{major}.{minor}.{patch}.{release}{candidate}',
+    '{major}.{minor}.{patch}'
+]
+search = '{current_version}'
+replace = '{new_version}'
+regex = false
+ignore_missing_version = false
+tag = true
+sign_tags = false
+tag_name = 'v{new_version}'
+tag_message = 'Bump version: {current_version} → {new_version}'
+allow_dirty = false
+commit = true
+message = 'Bump version: {current_version} → {new_version}'
+commit_args = ''
+
+[tool.bumpversion.parts.release]
+first_value = 'dev'
+optional_value = 'release'
+values = [
+    'dev',
+    'release'
+]
+
+[[tool.bumpversion.files]]
+filename = "ctgan/__init__.py"
+search = "__version__ = '{current_version}'"
+replace = "__version__ = '{new_version}'"
+
+[build-system]
+requires = ['setuptools', 'wheel']
+build-backend = 'setuptools.build_meta'
diff --git a/setup.cfg b/setup.cfg
index 021c34fb..6f3fc634 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,32 +1,3 @@
-[bumpversion]
-current_version = 0.7.5
-commit = True
-tag = True
-parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?
-serialize = 
-    {major}.{minor}.{patch}.{release}{candidate}
-    {major}.{minor}.{patch}
-
-[bumpversion:part:release]
-optional_value = release
-first_value = dev
-values = 
-    dev
-    release
-
-[bumpversion:part:candidate]
-
-[bumpversion:file:setup.py]
-search = version='{current_version}'
-replace = version='{new_version}'
-
-[bumpversion:file:ctgan/__init__.py]
-search = __version__ = '{current_version}'
-replace = __version__ = '{new_version}'
-
-[bdist_wheel]
-universal = 1
-
 [flake8]
 convention = google
 max-line-length = 99
@@ -39,17 +10,6 @@ extend-ignore = D107, # Missing docstring in __init__
 per-file-ignores =
     ctgan/data.py:T001
 
-[isort]
-include_trailing_comment = True
-line_length = 99
-lines_between_types = 0
-multi_line_output = 4
-not_skip = __init__.py
-use_parentheses = True
-
 [aliases]
 test = pytest
 
-[tool:pytest]
-collect_ignore = ['setup.py']
-
diff --git a/setup.py b/setup.py
deleted file mode 100644
index dc7ad639..00000000
--- a/setup.py
+++ /dev/null
@@ -1,124 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-"""The setup script."""
-
-from setuptools import find_packages, setup
-
-with open('README.md', encoding='utf-8') as readme_file:
-    readme = readme_file.read()
-
-with open('HISTORY.md', encoding='utf-8') as history_file:
-    history = history_file.read()
-
-install_requires = [
-    "numpy>=1.20.0,<2;python_version<'3.10'",
-    "numpy>=1.23.3,<2;python_version>='3.10'",
-    "pandas>=1.1.3;python_version<'3.10'",
-    "pandas>=1.3.4;python_version>='3.10' and python_version<'3.11'",
-    "pandas>=1.5.0;python_version>='3.11'",
-    "scikit-learn>=1.1.3,<2;python_version>='3.10'",
-    "torch>=1.8.0;python_version<'3.10'",
-    "torch>=1.11.0;python_version>='3.10' and python_version<'3.11'",
-    "torch>=2.0.0;python_version>='3.11'",
-    'tqdm>=4.15,<5',
-    'rdt>=1.6.1,<2.0',
-]
-
-setup_requires = [
-    'pytest-runner>=2.11.1',
-]
-
-tests_require = [
-    'pytest>=3.4.2',
-    'pytest-rerunfailures>=9.1.1,<10',
-    'pytest-cov>=2.6.0',
-    'rundoc>=0.4.3,<0.5',
-]
-
-development_requires = [
-    # general
-    'pip>=9.0.1',
-    'bumpversion>=0.5.3,<0.6',
-    'watchdog>=0.8.3,<0.11',
-
-    # style check
-    'flake8>=3.7.7,<4',
-    'isort>=4.3.4,<5',
-    'dlint>=0.11.0,<0.12',  # code security addon for flake8
-    'flake8-debugger>=4.0.0,<4.1',
-    'flake8-mock>=0.3,<0.4',
-    'flake8-mutable>=1.2.0,<1.3',
-    'flake8-absolute-import>=1.0,<2',
-    'flake8-multiline-containers>=0.0.18,<0.1',
-    'flake8-print>=4.0.0,<4.1',
-    'flake8-quotes>=3.3.0,<4',
-    'flake8-fixme>=1.1.1,<1.2',
-    'flake8-expression-complexity>=0.0.9,<0.1',
-    'flake8-eradicate>=1.1.0,<1.2',
-    'flake8-builtins>=1.5.3,<1.6',
-    'flake8-variables-names>=0.0.4,<0.1',
-    'pandas-vet>=0.2.2,<0.3',
-    'flake8-comprehensions>=3.6.1,<3.7',
-    'dlint>=0.11.0,<0.12',
-    'flake8-docstrings>=1.5.0,<2',
-    'flake8-sfs>=0.0.3,<0.1',
-    'flake8-pytest-style>=1.5.0,<2',
-
-    # fix style issues
-    'autoflake>=1.1,<2',
-    'autopep8>=1.4.3,<1.6',
-
-    # distribute on PyPI
-    'twine>=1.10.0,<4',
-    'wheel>=0.30.0',
-
-    # Advanced testing
-    'coverage>=4.5.1,<6',
-    'tox>=2.9.1,<4',
-
-    'invoke',
-]
-
-setup(
-    author='DataCebo, Inc.',
-    author_email='info@sdv.dev',
-    classifiers=[
-        'Development Status :: 2 - Pre-Alpha',
-        'Intended Audience :: Developers',
-        'License :: Free for non-commercial use',
-        'Natural Language :: English',
-        'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.8',
-        'Programming Language :: Python :: 3.9',
-        'Programming Language :: Python :: 3.10',
-        'Programming Language :: Python :: 3.11',
-        'Topic :: Scientific/Engineering :: Artificial Intelligence',
-    ],
-    description='Create tabular synthetic data using a conditional GAN',
-    entry_points={
-        'console_scripts': [
-            'ctgan=ctgan.__main__:main'
-        ],
-    },
-    extras_require={
-        'test': tests_require,
-        'dev': development_requires + tests_require,
-    },
-    install_package_data=True,
-    install_requires=install_requires,
-    license='BSL-1.1',
-    long_description=readme + '\n\n' + history,
-    long_description_content_type='text/markdown',
-    include_package_data=True,
-    keywords='ctgan CTGAN',
-    name='ctgan',
-    packages=find_packages(include=['ctgan', 'ctgan.*']),
-    python_requires='>=3.8,<3.12',
-    setup_requires=setup_requires,
-    test_suite='tests',
-    tests_require=tests_require,
-    url='https://github.com/sdv-dev/CTGAN',
-    version='0.7.5',
-    zip_safe=False,
-)
diff --git a/tasks.py b/tasks.py
index f5782f0f..26adfd76 100644
--- a/tasks.py
+++ b/tasks.py
@@ -1,15 +1,15 @@
-import glob
 import inspect
 import operator
 import os
-import re
-import pkg_resources
-import platform
 import shutil
 import stat
+import sys
 from pathlib import Path
 
+import tomli
 from invoke import task
+from packaging.requirements import Requirement
+from packaging.version import Version
 
 COMPARISONS = {
     '>=': operator.ge,
@@ -52,49 +52,45 @@ def readme(c):
     shutil.rmtree(test_path)
 
 
-def _validate_python_version(line):
-    is_valid = True
-    for python_version_match in re.finditer(r"python_version(<=?|>=?|==)\'(\d\.?)+\'", line):
-        python_version = python_version_match.group(0)
-        comparison = re.search(r'(>=?|<=?|==)', python_version).group(0)
-        version_number = python_version.split(comparison)[-1].replace("'", "")
-        comparison_function = COMPARISONS[comparison]
-        is_valid = is_valid and comparison_function(
-            pkg_resources.parse_version(platform.python_version()),
-            pkg_resources.parse_version(version_number),
-        )
+def _get_minimum_versions(dependencies, python_version):
+    min_versions = {}
+    for dependency in dependencies:
+        if '@' in dependency:
+            name, url = dependency.split(' @ ')
+            min_versions[name] = f'{name} @ {url}'
+            continue
 
-    return is_valid
+        req = Requirement(dependency)
+        if ';' in dependency:
+            marker = req.marker
+            if marker and not marker.evaluate({'python_version': python_version}):
+                continue  # Skip this dependency if the marker does not apply to the current Python version
+
+        if req.name not in min_versions:
+            min_version = next((spec.version for spec in req.specifier if spec.operator in ('>=', '==')), None)
+            if min_version:
+                min_versions[req.name] = f'{req.name}=={min_version}'
+
+        elif '@' not in min_versions[req.name]:
+            existing_version = Version(min_versions[req.name].split('==')[1])
+            new_version = next((spec.version for spec in req.specifier if spec.operator in ('>=', '==')), existing_version)
+            if new_version > existing_version:
+                min_versions[req.name] = f'{req.name}=={new_version}'  # Change when a valid newer version is found
+
+    return list(min_versions.values())
 
 
 @task
 def install_minimum(c):
-    with open('setup.py', 'r') as setup_py:
-        lines = setup_py.read().splitlines()
-
-    versions = []
-    started = False
-    for line in lines:
-        if started:
-            if line == ']':
-                break
-
-            line = line.strip()
-            if _validate_python_version(line):
-                requirement = re.match(r'[^>]*', line).group(0)
-                requirement = re.sub(r"""['",]""", '', requirement)
-                version = re.search(r'>=?(\d\.?)+', line).group(0)
-                if version:
-                    version = re.sub(r'>=?', '==', version)
-                    version = re.sub(r"""['",]""", '', version)
-                    requirement += version
-
-                versions.append(requirement)
-
-            elif line.startswith('install_requires = ['):
-                started = True
-
-    c.run(f'python -m pip install {" ".join(versions)}')
+    with open('pyproject.toml', 'rb') as pyproject_file:
+        pyproject_data = tomli.load(pyproject_file)
+
+    dependencies = pyproject_data.get('project', {}).get('dependencies', [])
+    python_version = '.'.join(map(str, sys.version_info[:2]))
+    minimum_versions = _get_minimum_versions(dependencies, python_version)
+
+    if minimum_versions:
+        c.run(f'python -m pip install {" ".join(minimum_versions)}')
 
 
 @task
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..28045a5d
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+"""CTGAN tests."""
diff --git a/tests/integration/synthesizer/test_tvae.py b/tests/integration/synthesizer/test_tvae.py
index 5e50c4ea..1cf93774 100644
--- a/tests/integration/synthesizer/test_tvae.py
+++ b/tests/integration/synthesizer/test_tvae.py
@@ -16,14 +16,17 @@
 from ctgan.synthesizers.tvae import TVAE
 
 
-def test_tvae(tmpdir):
+def test_tvae(tmpdir, capsys):
     """Test the TVAE load/save methods."""
+    # Setup
     iris = datasets.load_iris()
     data = pd.DataFrame(iris.data, columns=iris.feature_names)
     data['class'] = pd.Series(iris.target).map(iris.target_names.__getitem__)
+    tvae = TVAE(epochs=10, verbose=True)
 
-    tvae = TVAE(epochs=10)
+    # Run
     tvae.fit(data, ['class'])
+    captured_out = capsys.readouterr().err
 
     path = str(tmpdir / 'test_tvae.pkl')
     tvae.save(path)
@@ -31,10 +34,17 @@ def test_tvae(tmpdir):
 
     sampled = tvae.sample(100)
 
+    # Assert
     assert sampled.shape == (100, 5)
     assert isinstance(sampled, pd.DataFrame)
     assert set(sampled.columns) == set(data.columns)
     assert set(sampled.dtypes) == set(data.dtypes)
+    loss_values = tvae.loss_values
+    assert len(loss_values) == 10
+    assert set(loss_values.columns) == {'Epoch', 'Batch', 'Loss'}
+    assert all(loss_values['Batch'] == 0)
+    last_loss_val = loss_values['Loss'].iloc[-1]
+    assert f'Loss: {round(last_loss_val, 3):.3f}: 100%' in captured_out
 
 
 def test_drop_last_false():
diff --git a/tests/test_tasks.py b/tests/test_tasks.py
new file mode 100644
index 00000000..c78986cf
--- /dev/null
+++ b/tests/test_tasks.py
@@ -0,0 +1,38 @@
+"""Tests for the ``tasks.py`` file."""
+from tasks import _get_minimum_versions
+
+
+def test_get_minimum_versions():
+    """Test the ``_get_minimum_versions`` method.
+
+    The method should return the minimum versions of the dependencies for the given python version.
+    If a library is linked to an URL, the minimum version should be the URL.
+ """ + # Setup + dependencies = [ + "numpy>=1.20.0,<2;python_version<'3.10'", + "numpy>=1.23.3,<2;python_version>='3.10'", + "pandas>=1.2.0,<2;python_version<'3.10'", + "pandas>=1.3.0,<2;python_version>='3.10'", + 'humanfriendly>=8.2,<11', + 'pandas @ git+https://github.com/pandas-dev/pandas.git@master#egg=pandas' + ] + + # Run + minimum_versions_39 = _get_minimum_versions(dependencies, '3.9') + minimum_versions_310 = _get_minimum_versions(dependencies, '3.10') + + # Assert + expected_versions_39 = [ + 'numpy==1.20.0', + 'pandas @ git+https://github.com/pandas-dev/pandas.git@master#egg=pandas', + 'humanfriendly==8.2', + ] + expected_versions_310 = [ + 'numpy==1.23.3', + 'pandas @ git+https://github.com/pandas-dev/pandas.git@master#egg=pandas', + 'humanfriendly==8.2', + ] + + assert minimum_versions_39 == expected_versions_39 + assert minimum_versions_310 == expected_versions_310 diff --git a/tests/unit/synthesizer/test_tvae.py b/tests/unit/synthesizer/test_tvae.py new file mode 100644 index 00000000..98ee862f --- /dev/null +++ b/tests/unit/synthesizer/test_tvae.py @@ -0,0 +1,47 @@ +"""TVAE unit testing module.""" + +from unittest.mock import MagicMock, Mock, call, patch + +import pandas as pd + +from ctgan.synthesizers import TVAE + + +class TestTVAE: + @patch('ctgan.synthesizers.tvae._loss_function') + @patch('ctgan.synthesizers.tvae.tqdm') + def test_fit_verbose(self, tqdm_mock, loss_func_mock): + """Test verbose parameter prints progress bar.""" + # Setup + epochs = 1 + + def mock_iter(): + for i in range(epochs): + yield i + + def mock_add(a, b): + mock_loss = Mock() + mock_loss.detach().cpu().item.return_value = 1.23456789 + return mock_loss + + loss_mock = MagicMock() + loss_mock.__add__ = mock_add + loss_func_mock.return_value = (loss_mock, loss_mock) + + iterator_mock = MagicMock() + iterator_mock.__iter__.side_effect = mock_iter + tqdm_mock.return_value = iterator_mock + synth = TVAE(epochs=epochs, verbose=True) + train_data = pd.DataFrame({ + 'col1': [0, 1, 2, 3, 4], + 'col2': [10, 11, 12, 13, 14] + }) + + # Run + synth.fit(train_data) + + # Assert + tqdm_mock.assert_called_once_with(range(epochs), disable=False) + assert iterator_mock.set_description.call_args_list[0] == call('Loss: 0.000') + assert iterator_mock.set_description.call_args_list[1] == call('Loss: 1.235') + assert iterator_mock.set_description.call_count == 2