Skip to content

Commit

Permalink
Add the ability to ablate the conservative regularization loss
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 30, 2023
1 parent 5b0cc20 commit d81c7b5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
13 changes: 12 additions & 1 deletion q_transformer/q_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def __init__(

self.discount_factor_gamma = discount_factor_gamma
self.n_step_q_learning = n_step_q_learning

self.has_conservative_reg_loss = conservative_reg_loss_weight > 0.
self.conservative_reg_loss_weight = conservative_reg_loss_weight

self.register_buffer('discount_matrix', None, persistent = False)
Expand Down Expand Up @@ -186,6 +188,10 @@ def __init__(
self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
assert self.checkpoint_folder.is_dir()

# constant zero tensor, returned in place of the conservative regularization loss when that loss is disabled

self.register_buffer('zero', torch.tensor(0.))

# training step related

self.num_train_steps = num_train_steps
Expand All @@ -209,7 +215,8 @@ def save(
pkg = dict(
model = self.unwrap(self.model).state_dict(),
ema_model = self.unwrap(self.ema_model).state_dict(),
optimizer = self.optimizer.state_dict()
optimizer = self.optimizer.state_dict(),
step = self.step.item()
)

torch.save(pkg, str(path))
Expand All @@ -224,6 +231,7 @@ def load(self, path):
self.unwrap(self.ema_model).load_state_dict(pkg['ema_model'])

self.optimizer.load_state_dict(pkg['optimizer'])
self.step.copy_(pkg['step'])

@property
def device(self):
Expand Down Expand Up @@ -528,6 +536,9 @@ def learn(
td_loss, q_intermediates = self.q_learn(*args, **q_learn_kwargs)
num_timesteps = 1

if not self.has_conservative_reg_loss:
return loss, Losses(td_loss, self.zero)

# calculate conservative regularization
# section 4.2 in paper, eq 2

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
version = '0.0.31',
version = '0.0.32',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',
Expand Down

0 comments on commit d81c7b5

Please sign in to comment.