Commit: run ruff

gsheni committed Jun 21, 2024
1 parent 67d4388  commit f1c0104
Showing 20 changed files with 421 additions and 382 deletions.
8 changes: 3 additions & 5 deletions Makefile
@@ -80,14 +80,12 @@ install-develop: clean-build clean-pyc ## install the package in editable mode a
 # LINT TARGETS
 
 .PHONY: lint
-lint: ## check style with flake8 and isort
+lint:
 	invoke lint
 
 .PHONY: fix-lint
-fix-lint: ## fix lint issues using autoflake, autopep8, and isort
-	find ctgan tests -name '*.py' | xargs autoflake --in-place --remove-all-unused-imports --remove-unused-variables
-	autopep8 --in-place --recursive --aggressive ctgan tests
-	isort --apply --atomic --recursive ctgan tests
+fix-lint:
+	invoke fix-lint
 
 
 # TEST TARGETS
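Both targets now delegate to invoke tasks instead of running flake8/isort tooling directly. A minimal sketch of what the corresponding tasks.py could look like, assuming the commit moves linting to ruff as the title suggests (the repository's actual task definitions may differ):

# tasks.py (hypothetical sketch): the invoke tasks the Makefile targets now call.
from invoke import task


@task
def lint(c):
    """Check style and formatting with ruff."""
    c.run('ruff check .')
    c.run('ruff format . --check')


@task(name='fix-lint')
def fix_lint(c):
    """Auto-fix style issues and reformat with ruff."""
    c.run('ruff check . --fix')
    c.run('ruff format .')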
6 changes: 1 addition & 5 deletions ctgan/__init__.py
@@ -10,8 +10,4 @@
 from ctgan.synthesizers.ctgan import CTGAN
 from ctgan.synthesizers.tvae import TVAE
 
-__all__ = (
-    'CTGAN',
-    'TVAE',
-    'load_demo'
-)
+__all__ = ('CTGAN', 'TVAE', 'load_demo')
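Only the formatting of __all__ changes; the public API is unchanged. For reference, a minimal usage sketch of the exports (the discrete column names below are hypothetical):

from ctgan import CTGAN, TVAE, load_demo

data = load_demo()  # demo training data as a DataFrame
discrete_columns = ['workclass', 'education']  # hypothetical column names

model = CTGAN(epochs=10)  # TVAE exposes a similar fit/sample interface
model.fit(data, discrete_columns)
synthetic = model.sample(1000)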
131 changes: 82 additions & 49 deletions ctgan/__main__.py
@@ -8,49 +8,78 @@

 def _parse_args():
     parser = argparse.ArgumentParser(description='CTGAN Command Line Interface')
-    parser.add_argument('-e', '--epochs', default=300, type=int,
-                        help='Number of training epochs')
-    parser.add_argument('-t', '--tsv', action='store_true',
-                        help='Load data in TSV format instead of CSV')
-    parser.add_argument('--no-header', dest='header', action='store_false',
-                        help='The CSV file has no header. Discrete columns will be indices.')
+    parser.add_argument('-e', '--epochs', default=300, type=int, help='Number of training epochs')
+    parser.add_argument(
+        '-t', '--tsv', action='store_true', help='Load data in TSV format instead of CSV'
+    )
+    parser.add_argument(
+        '--no-header',
+        dest='header',
+        action='store_false',
+        help='The CSV file has no header. Discrete columns will be indices.',
+    )
 
     parser.add_argument('-m', '--metadata', help='Path to the metadata')
-    parser.add_argument('-d', '--discrete',
-                        help='Comma separated list of discrete columns without whitespaces.')
-    parser.add_argument('-n', '--num-samples', type=int,
-                        help='Number of rows to sample. Defaults to the training data size')
-
-    parser.add_argument('--generator_lr', type=float, default=2e-4,
-                        help='Learning rate for the generator.')
-    parser.add_argument('--discriminator_lr', type=float, default=2e-4,
-                        help='Learning rate for the discriminator.')
-
-    parser.add_argument('--generator_decay', type=float, default=1e-6,
-                        help='Weight decay for the generator.')
-    parser.add_argument('--discriminator_decay', type=float, default=0,
-                        help='Weight decay for the discriminator.')
-
-    parser.add_argument('--embedding_dim', type=int, default=128,
-                        help='Dimension of input z to the generator.')
-    parser.add_argument('--generator_dim', type=str, default='256,256',
-                        help='Dimension of each generator layer. '
-                             'Comma separated integers with no whitespaces.')
-    parser.add_argument('--discriminator_dim', type=str, default='256,256',
-                        help='Dimension of each discriminator layer. '
-                             'Comma separated integers with no whitespaces.')
-
-    parser.add_argument('--batch_size', type=int, default=500,
-                        help='Batch size. Must be an even number.')
-    parser.add_argument('--save', default=None, type=str,
-                        help='A filename to save the trained synthesizer.')
-    parser.add_argument('--load', default=None, type=str,
-                        help='A filename to load a trained synthesizer.')
-
-    parser.add_argument('--sample_condition_column', default=None, type=str,
-                        help='Select a discrete column name.')
-    parser.add_argument('--sample_condition_column_value', default=None, type=str,
-                        help='Specify the value of the selected discrete column.')
+    parser.add_argument(
+        '-d', '--discrete', help='Comma separated list of discrete columns without whitespaces.'
+    )
+    parser.add_argument(
+        '-n',
+        '--num-samples',
+        type=int,
+        help='Number of rows to sample. Defaults to the training data size',
+    )
+
+    parser.add_argument(
+        '--generator_lr', type=float, default=2e-4, help='Learning rate for the generator.'
+    )
+    parser.add_argument(
+        '--discriminator_lr', type=float, default=2e-4, help='Learning rate for the discriminator.'
+    )
+
+    parser.add_argument(
+        '--generator_decay', type=float, default=1e-6, help='Weight decay for the generator.'
+    )
+    parser.add_argument(
+        '--discriminator_decay', type=float, default=0, help='Weight decay for the discriminator.'
+    )
+
+    parser.add_argument(
+        '--embedding_dim', type=int, default=128, help='Dimension of input z to the generator.'
+    )
+    parser.add_argument(
+        '--generator_dim',
+        type=str,
+        default='256,256',
+        help='Dimension of each generator layer. ' 'Comma separated integers with no whitespaces.',
+    )
+    parser.add_argument(
+        '--discriminator_dim',
+        type=str,
+        default='256,256',
+        help='Dimension of each discriminator layer. '
+        'Comma separated integers with no whitespaces.',
+    )
+
+    parser.add_argument(
+        '--batch_size', type=int, default=500, help='Batch size. Must be an even number.'
+    )
+    parser.add_argument(
+        '--save', default=None, type=str, help='A filename to save the trained synthesizer.'
+    )
+    parser.add_argument(
+        '--load', default=None, type=str, help='A filename to load a trained synthesizer.'
+    )
+
+    parser.add_argument(
+        '--sample_condition_column', default=None, type=str, help='Select a discrete column name.'
+    )
+    parser.add_argument(
+        '--sample_condition_column_value',
+        default=None,
+        type=str,
+        help='Specify the value of the selected discrete column.',
+    )
 
     parser.add_argument('data', help='Path to training data')
     parser.add_argument('output', help='Path of the output file')
@@ -72,11 +101,16 @@ def main():
     generator_dim = [int(x) for x in args.generator_dim.split(',')]
     discriminator_dim = [int(x) for x in args.discriminator_dim.split(',')]
     model = CTGAN(
-        embedding_dim=args.embedding_dim, generator_dim=generator_dim,
-        discriminator_dim=discriminator_dim, generator_lr=args.generator_lr,
-        generator_decay=args.generator_decay, discriminator_lr=args.discriminator_lr,
-        discriminator_decay=args.discriminator_decay, batch_size=args.batch_size,
-        epochs=args.epochs)
+        embedding_dim=args.embedding_dim,
+        generator_dim=generator_dim,
+        discriminator_dim=discriminator_dim,
+        generator_lr=args.generator_lr,
+        generator_decay=args.generator_decay,
+        discriminator_lr=args.discriminator_lr,
+        discriminator_decay=args.discriminator_decay,
+        batch_size=args.batch_size,
+        epochs=args.epochs,
+    )
     model.fit(data, discrete_columns)
 
     if args.save is not None:
@@ -88,9 +122,8 @@ def main():
         assert args.sample_condition_column_value is not None
 
     sampled = model.sample(
-        num_samples,
-        args.sample_condition_column,
-        args.sample_condition_column_value)
+        num_samples, args.sample_condition_column, args.sample_condition_column_value
+    )
 
     if args.tsv:
         write_tsv(sampled, args.metadata, args.output)
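The reformatted _parse_args defines the same command-line interface as before, so a hypothetical invocation still looks like this (file and column names are illustrative):

python -m ctgan --epochs 100 -d workclass,education -n 5000 train.csv synthetic.csv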
12 changes: 3 additions & 9 deletions ctgan/data.py
@@ -15,9 +15,7 @@ def read_csv(csv_filename, meta_filename=None, header=True, discrete=None):
         metadata = json.load(meta_file)
 
         discrete_columns = [
-            column['name']
-            for column in metadata['columns']
-            if column['type'] != 'continuous'
+            column['name'] for column in metadata['columns'] if column['type'] != 'continuous'
         ]
 
     elif discrete:
@@ -36,10 +34,7 @@ def read_tsv(data_filename, meta_filename):
     with open(meta_filename) as f:
         column_info = f.readlines()
 
-    column_info_raw = [
-        x.replace('{', ' ').replace('}', ' ').split()
-        for x in column_info
-    ]
+    column_info_raw = [x.replace('{', ' ').replace('}', ' ').split() for x in column_info]
 
     discrete = []
     continuous = []
@@ -57,7 +52,7 @@ def read_tsv(data_filename, meta_filename):
     meta = {
         'continuous_columns': continuous,
         'discrete_columns': discrete,
-        'column_info': column_info
+        'column_info': column_info,
     }
 
     with open(data_filename) as f:
@@ -82,7 +77,6 @@ def read_tsv(data_filename, meta_filename):
 def write_tsv(data, meta, output_filename):
     """Write to a tsv file."""
     with open(output_filename, 'w') as f:
-
         for row in data:
             for idx, col in enumerate(row):
                 if idx in meta['continuous_columns']:
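In read_csv, every metadata column whose type is not 'continuous' is treated as discrete. A small self-contained sketch of the comprehension above, with made-up column names and types:

metadata = {
    'columns': [
        {'name': 'age', 'type': 'continuous'},  # excluded: continuous
        {'name': 'workclass', 'type': 'categorical'},  # included: not continuous
    ]
}
discrete_columns = [
    column['name'] for column in metadata['columns'] if column['type'] != 'continuous'
]
print(discrete_columns)  # ['workclass']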
32 changes: 14 additions & 18 deletions ctgan/data_sampler.py
@@ -10,14 +10,13 @@ def __init__(self, data, output_info, log_frequency):
         self._data_length = len(data)
 
         def is_discrete_column(column_info):
-            return (len(column_info) == 1
-                    and column_info[0].activation_fn == 'softmax')
+            return len(column_info) == 1 and column_info[0].activation_fn == 'softmax'
 
-        n_discrete_columns = sum(
-            [1 for column_info in output_info if is_discrete_column(column_info)])
+        n_discrete_columns = sum([
+            1 for column_info in output_info if is_discrete_column(column_info)
+        ])
 
-        self._discrete_column_matrix_st = np.zeros(
-            n_discrete_columns, dtype='int32')
+        self._discrete_column_matrix_st = np.zeros(n_discrete_columns, dtype='int32')
 
         # Store the row id for each category in each discrete column.
         # For example _rid_by_cat_cols[a][b] is a list of all rows with the
@@ -41,20 +40,17 @@ def is_discrete_column(column_info):
         assert st == data.shape[1]
 
         # Prepare an interval matrix for efficiently sample conditional vector
-        max_category = max([
-            column_info[0].dim
-            for column_info in output_info
-            if is_discrete_column(column_info)
-        ], default=0)
+        max_category = max(
+            [column_info[0].dim for column_info in output_info if is_discrete_column(column_info)],
+            default=0,
+        )
 
         self._discrete_column_cond_st = np.zeros(n_discrete_columns, dtype='int32')
         self._discrete_column_n_category = np.zeros(n_discrete_columns, dtype='int32')
         self._discrete_column_category_prob = np.zeros((n_discrete_columns, max_category))
         self._n_discrete_columns = n_discrete_columns
         self._n_categories = sum([
-            column_info[0].dim
-            for column_info in output_info
-            if is_discrete_column(column_info)
+            column_info[0].dim for column_info in output_info if is_discrete_column(column_info)
         ])
 
         st = 0
@@ -68,7 +64,7 @@ def is_discrete_column(column_info):
                 if log_frequency:
                     category_freq = np.log(category_freq + 1)
                 category_prob = category_freq / np.sum(category_freq)
-                self._discrete_column_category_prob[current_id, :span_info.dim] = category_prob
+                self._discrete_column_category_prob[current_id, : span_info.dim] = category_prob
                 self._discrete_column_cond_st[current_id] = current_cond_st
                 self._discrete_column_n_category[current_id] = span_info.dim
                 current_cond_st += span_info.dim
@@ -98,14 +94,13 @@ def sample_condvec(self, batch):
         if self._n_discrete_columns == 0:
             return None
 
-        discrete_column_id = np.random.choice(
-            np.arange(self._n_discrete_columns), batch)
+        discrete_column_id = np.random.choice(np.arange(self._n_discrete_columns), batch)
 
         cond = np.zeros((batch, self._n_categories), dtype='float32')
         mask = np.zeros((batch, self._n_discrete_columns), dtype='float32')
         mask[np.arange(batch), discrete_column_id] = 1
         category_id_in_col = self._random_choice_prob_index(discrete_column_id)
-        category_id = (self._discrete_column_cond_st[discrete_column_id] + category_id_in_col)
+        category_id = self._discrete_column_cond_st[discrete_column_id] + category_id_in_col
         cond[np.arange(batch), category_id] = 1
 
         return cond, mask, discrete_column_id, category_id_in_col
@@ -130,6 +125,7 @@ def sample_data(self, data, n, col, opt):
         Args:
             data:
                 The training data.
+        Returns:
             n:
                 n rows of matrix data.
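For context on the code being reformatted: sample_condvec draws one discrete column per batch row, then one category within that column, and one-hot encodes the result into cond, while mask records which column was chosen. A standalone sketch of that encoding step with toy sizes (two columns with 3 and 2 categories; uniform category choice stands in for _random_choice_prob_index):

import numpy as np

batch = 4
n_discrete_columns = 2
n_categories = 5  # 3 categories in column 0 plus 2 in column 1
cond_st = np.array([0, 3])  # offset of each column's categories inside cond
n_category = np.array([3, 2])  # categories per column

discrete_column_id = np.random.choice(np.arange(n_discrete_columns), batch)
category_id_in_col = np.random.randint(0, n_category[discrete_column_id])

cond = np.zeros((batch, n_categories), dtype='float32')
mask = np.zeros((batch, n_discrete_columns), dtype='float32')
mask[np.arange(batch), discrete_column_id] = 1
cond[np.arange(batch), cond_st[discrete_column_id] + category_id_in_col] = 1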