From d81c7b51867be26a7fc7d39e6486f986ef568307 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 30 Nov 2023 11:19:36 -0800
Subject: [PATCH] able to ablate conservative reg

---
 q_transformer/q_learner.py | 13 ++++++++++++-
 setup.py                   |  2 +-
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/q_transformer/q_learner.py b/q_transformer/q_learner.py
index b6b00c8..4c9609b 100644
--- a/q_transformer/q_learner.py
+++ b/q_transformer/q_learner.py
@@ -120,6 +120,8 @@ def __init__(
 
         self.discount_factor_gamma = discount_factor_gamma
         self.n_step_q_learning = n_step_q_learning
+
+        self.has_conservative_reg_loss = conservative_reg_loss_weight > 0.
         self.conservative_reg_loss_weight = conservative_reg_loss_weight
 
         self.register_buffer('discount_matrix', None, persistent = False)
@@ -186,6 +188,10 @@ def __init__(
         self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
         assert self.checkpoint_folder.is_dir()
 
+        # dummy loss
+
+        self.register_buffer('zero', torch.tensor(0.))
+
         # training step related
 
         self.num_train_steps = num_train_steps
@@ -209,7 +215,8 @@ def save(
         pkg = dict(
             model = self.unwrap(self.model).state_dict(),
             ema_model = self.unwrap(self.ema_model).state_dict(),
-            optimizer = self.optimizer.state_dict()
+            optimizer = self.optimizer.state_dict(),
+            step = self.step.item()
         )
 
         torch.save(pkg, str(path))
@@ -224,6 +231,7 @@ def load(self, path):
 
         self.unwrap(self.ema_model).load_state_dict(pkg['ema_model'])
         self.optimizer.load_state_dict(pkg['optimizer'])
+        self.step.copy_(pkg['step'])
 
     @property
     def device(self):
@@ -528,6 +536,9 @@ def learn(
             td_loss, q_intermediates = self.q_learn(*args, **q_learn_kwargs)
             num_timesteps = 1
 
+        if not self.has_conservative_reg_loss:
+            return td_loss, Losses(td_loss, self.zero)
+
         # calculate conservative regularization
         # section 4.2 in paper, eq 2
 
diff --git a/setup.py b/setup.py
index 411b21b..1b1615a 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'q-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.0.31',
+  version = '0.0.32',
   license='MIT',
   description = 'Q-Transformer',
   author = 'Phil Wang',
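
Usage note: with this patch, passing a non-positive conservative_reg_loss_weight to QLearner ablates the conservative regularization entirely; learn() then returns only the TD loss, reporting the regularization term as the registered zero buffer. Checkpoints written by save() also now record the training step, which load() restores. Below is a minimal sketch of the ablation; the QRoboticTransformer / ReplayMemoryDataset constructor calls are placeholders and the remaining QLearner keyword arguments are assumed from the project README, so the exact arguments may differ in this version.

    from q_transformer import QRoboticTransformer, QLearner, ReplayMemoryDataset

    # placeholder construction - see the README for the actual required arguments
    model = QRoboticTransformer()
    dataset = ReplayMemoryDataset()

    q_learner = QLearner(
        model,
        dataset = dataset,
        num_train_steps = 10000,
        learning_rate = 3e-4,
        batch_size = 4,
        grad_accum_every = 16,
        conservative_reg_loss_weight = 0.  # non-positive weight sets has_conservative_reg_loss to False
    )

    # assumed entry point from the README; training now skips the conservative
    # regularization term (section 4.2, eq 2 of the paper) and optimizes the TD loss alone
    q_learner()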