Skip to content

Commit

Permalink
make release-tag: Merge branch 'main' into stable
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h committed Feb 13, 2024
2 parents dbe8cd2 + e228a54 commit aef4e07
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 19 deletions.
8 changes: 8 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# History

## v0.9.0 - 2024-02-13

This release makes CTGAN sampling more efficient by saving the frequency of each categorical value.

### New Features

* Improve DataSampler efficiency - Issue [#327](https://github.com/sdv-dev/CTGAN/issues/327) by @fealho

## v0.8.0 - 2023-11-13

This release adds a progress bar that will show when setting the `verbose` parameter to `True`
Expand Down
2 changes: 1 addition & 1 deletion ctgan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

__author__ = 'DataCebo, Inc.'
__email__ = '[email protected]'
__version__ = '0.8.0'
__version__ = '0.9.0.dev1'

from ctgan.demo import load_demo
from ctgan.synthesizers.ctgan import CTGAN
Expand Down
29 changes: 15 additions & 14 deletions ctgan/data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class DataSampler(object):
"""DataSampler samples the conditional vector and corresponding data for CTGAN."""

def __init__(self, data, output_info, log_frequency):
self._data = data
self._data_length = len(data)

def is_discrete_column(column_info):
return (len(column_info) == 1
Expand Down Expand Up @@ -115,33 +115,34 @@ def sample_original_condvec(self, batch):
if self._n_discrete_columns == 0:
return None

category_freq = self._discrete_column_category_prob.flatten()
category_freq = category_freq[category_freq != 0]
category_freq = category_freq / np.sum(category_freq)
col_idxs = np.random.choice(np.arange(len(category_freq)), batch, p=category_freq)
cond = np.zeros((batch, self._n_categories), dtype='float32')

for i in range(batch):
row_idx = np.random.randint(0, len(self._data))
col_idx = np.random.randint(0, self._n_discrete_columns)
matrix_st = self._discrete_column_matrix_st[col_idx]
matrix_ed = matrix_st + self._discrete_column_n_category[col_idx]
pick = np.argmax(self._data[row_idx, matrix_st:matrix_ed])
cond[i, pick + self._discrete_column_cond_st[col_idx]] = 1
cond[np.arange(batch), col_idxs] = 1

return cond

def sample_data(self, n, col, opt):
def sample_data(self, data, n, col, opt):
"""Sample data from original training data satisfying the sampled conditional vector.
Args:
data:
The training data.
Returns:
n rows of matrix data.
n:
n rows of matrix data.
"""
if col is None:
idx = np.random.randint(len(self._data), size=n)
return self._data[idx]
idx = np.random.randint(len(data), size=n)
return data[idx]

idx = []
for c, o in zip(col, opt):
idx.append(np.random.choice(self._rid_by_cat_cols[c][o]))

return self._data[idx]
return data[idx]

def dim_cond_vec(self):
"""Return the total number of categories."""
Expand Down
5 changes: 3 additions & 2 deletions ctgan/synthesizers/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,8 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
condvec = self._data_sampler.sample_condvec(self._batch_size)
if condvec is None:
c1, m1, col, opt = None, None, None, None
real = self._data_sampler.sample_data(self._batch_size, col, opt)
real = self._data_sampler.sample_data(
train_data, self._batch_size, col, opt)
else:
c1, m1, col, opt = condvec
c1 = torch.from_numpy(c1).to(self._device)
Expand All @@ -365,7 +366,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
perm = np.arange(self._batch_size)
np.random.shuffle(perm)
real = self._data_sampler.sample_data(
self._batch_size, col[perm], opt[perm])
train_data, self._batch_size, col[perm], opt[perm])
c2 = c1[perm]

fake = self._generator(fakez)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.8.0
current_version = 0.9.0.dev1
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,6 @@
test_suite='tests',
tests_require=tests_require,
url='https://github.com/sdv-dev/CTGAN',
version='0.8.0',
version='0.9.0.dev1',
zip_safe=False,
)

0 comments on commit aef4e07

Please sign in to comment.