From bdf831d748effecdd91d92390af72768ef92222e Mon Sep 17 00:00:00 2001 From: Michael Fuest Date: Fri, 20 Dec 2024 07:44:44 -0500 Subject: [PATCH] cond module updates --- config/config.yaml | 2 +- config/model/acgan.yaml | 6 +++--- eval/evaluator.py | 2 +- generator/conditioning.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 76c7899..7f605aa 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -4,7 +4,7 @@ defaults: - evaluator: default - _self_ -device: 0 # 0, cpu +device: 1 # 0, cpu job_name: ${model.name}_${dataset.name}_${dataset.user_group} run_dir: outputs/${job_name}/${now:%Y-%m-%d_%H-%M-%S} diff --git a/config/model/acgan.yaml b/config/model/acgan.yaml index 437bdb9..2ce6edf 100644 --- a/config/model/acgan.yaml +++ b/config/model/acgan.yaml @@ -3,11 +3,11 @@ name: acgan noise_dim: 256 cond_emb_dim: 64 sparse_conditioning_loss_weight: 0.5 -kl_weight: 0.1 +kl_weight: 0.5 -batch_size: 32 +batch_size: 64 n_epochs: 200 -lr_gen: 3e-4 +lr_gen: 2e-4 lr_discr: 1e-4 warm_up_epochs: 100 include_auxiliary_losses: True diff --git a/eval/evaluator.py b/eval/evaluator.py index b0dbc46..5164326 100644 --- a/eval/evaluator.py +++ b/eval/evaluator.py @@ -263,7 +263,7 @@ def create_visualizations( dataset: Any, model: Any, num_samples: int = 100, - num_runs: int = 10, + num_runs: int = 3, ): """ Create various visualizations for the evaluation results. diff --git a/generator/conditioning.py b/generator/conditioning.py index 7828881..f15aa19 100644 --- a/generator/conditioning.py +++ b/generator/conditioning.py @@ -3,7 +3,7 @@ class ConditioningModule(nn.Module): - def __init__(self, categorical_dims, embedding_dim, device, alpha=0.8): + def __init__(self, categorical_dims, embedding_dim, device, alpha=0.1): super(ConditioningModule, self).__init__() self.embedding_dim = embedding_dim self.device = device