From a521b02d643ed0062fe9d8bf21745fc6b15ebd98 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 3 Feb 2024 07:30:05 -0800 Subject: [PATCH] address https://github.com/lucidrains/q-transformer/issues/10 --- q_transformer/agent.py | 21 +++++++++++++++------ setup.py | 2 +- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/q_transformer/agent.py b/q_transformer/agent.py index 3d901b5..64acbc0 100644 --- a/q_transformer/agent.py +++ b/q_transformer/agent.py @@ -206,7 +206,6 @@ def __init__( if condition_on_text: text_embeds_path = mem_path / TEXT_EMBEDS_FILENAME text_embed_shape = environment.text_embed_shape - self.text_embed_shape = text_embed_shape self.text_embeds = open_memmap(str(text_embeds_path), dtype = 'float32', mode = 'w+', shape = (*prec_shape, *text_embed_shape)) @@ -269,11 +268,21 @@ def forward(self): curr_state = next_state if self.condition_on_text: - del self.text_embeds + self.text_embeds.flush() + + self.states.flush() + self.actions.flush() + self.rewards.flush() + self.dones.flush() + + # close memmap + + if self.condition_on_text: + del self.text_embeds - del self.states - del self.actions - del self.rewards - del self.dones + del self.states + del self.actions + del self.rewards + del self.dones print(f'completed, memories stored to {self.memories_dataset_folder.resolve()}') diff --git a/setup.py b/setup.py index d11c414..da1e771 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'q-transformer', packages = find_packages(exclude=[]), - version = '0.1.12', + version = '0.1.14', license='MIT', description = 'Q-Transformer', author = 'Phil Wang',