Skip to content

Commit

Permalink
queue as context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexPiche committed Oct 29, 2024
1 parent 46289f8 commit 249bbde
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 32 deletions.
6 changes: 3 additions & 3 deletions examples/rl_gsm8k/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@ python examples/rl_gsm8k/orchestrate_rl.py

![image](https://github.com/user-attachments/assets/c715de7a-8d15-4504-9c7c-d8ad28726941)

### Collect data
### Collect online RL training data

#### Collect tapes
* the current model (updated llama 3.1 8b) is served on all the gpus using vllm.
* a subset of 16 tasks from the train set of gsm8k is sampled and replicated 64 times each for a total of 1024 tasks.
* the agent produce complete tapes for each of these 1024 tasks using temperature 0.7.
* traces are created from these new tapes.
* the log prob of the traces under the current model are computed. We refer to these as log prob to be close to the naming of the grpo paper. #todo should be log prob and old log prob only in fine tune.
* the log prob of the traces under the current model are computed.

#### Annotate tapes with rewards
* For each trace, the reward is computed as follows:
* +1 for correct answer
* 0 for incorrect answer or no answer
* -1 for step that cannot be parsed
* -1 for step that cannot be parsed to json


#### Annotate tapes with ref log probs
Expand Down
17 changes: 9 additions & 8 deletions examples/rl_gsm8k/orchestrate_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from tapeagents.finetune.logging_ import flatten_dict_config, init_wandb
from tapeagents.io import save_json_tape
from tapeagents.llms import TrainableLLM
from tapeagents.observe import erase_sqlite, retrieve_all_llm_calls, start_sqlite_queue_writer, stop_sqlite_queue_writer
from tapeagents.observe import SQLiteQueueManager, retrieve_all_llm_calls

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -112,7 +112,13 @@ def generate_training_data(

logger.info("Starting main loop")
start_sampling_from_llm = time.time()
new_tapes = list(batch_main_loop(agent, tapes, env, max_loops=cfg.max_loops, n_workers=cfg.n_workers))

with SQLiteQueueManager() as queue_manager:
new_tapes = list(batch_main_loop(agent, tapes, env, max_loops=cfg.max_loops, n_workers=cfg.n_workers))
while not queue_manager.is_empty:
logging.info("Waiting for LLM calls to be written to SQLite")
time.sleep(5)

end_sampling_from_llm = time.time()
start_reading_sqlite = time.time()
if dataset_name == "train":
Expand Down Expand Up @@ -436,12 +442,7 @@ def main(cfg: DictConfig):
)
state["iteration"] += 1
save_state(state, state_path)
erase_sqlite()


if __name__ == "__main__":
try:
start_sqlite_queue_writer()
main()
finally:
stop_sqlite_queue_writer()
main()
107 changes: 86 additions & 21 deletions tapeagents/observe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import queue
import sqlite3
import threading
from typing import Callable, Type
import time
from typing import Callable, Optional, Type

from pydantic import BaseModel

Expand All @@ -14,8 +15,7 @@
logger = logging.getLogger(__name__)

_checked_sqlite = False
LLM_WRITE_QUEUE = None
_writer_thread = None
_ACTIVE_MANAGER: Optional['SQLiteQueueManager'] = None

LLMCallListener = Callable[[LLMCall], None]
TapeListener = Callable[[Tape], None]
Expand Down Expand Up @@ -113,12 +113,20 @@ def sqlite_writer(call):
logger.error(f"Failed to store LLMCall: {e}")



def sqlite_store_llm_call(call: LLMCall):
global LLM_WRITE_QUEUE
if LLM_WRITE_QUEUE is not None:
LLM_WRITE_QUEUE.put(call)
"""Standalone function to store LLM calls.
Will use the queue if available (within context manager),
otherwise falls back to single-threaded mode.
"""
if _ACTIVE_MANAGER is not None and _ACTIVE_MANAGER.queue is not None:
# We're in a context manager, use the queue
logger.debug("Using SQLite queue writing mode")
_ACTIVE_MANAGER.queue.put(call)
else:
logger.warning("writing would be single-threaded and blocking unless you start the queue")
# We're not in a context manager, use single-threaded mode
logger.debug("Using single-threaded SQLite writing mode")
sqlite_writer(call)


Expand Down Expand Up @@ -250,19 +258,76 @@ def dict_factory(cursor, row):
return calls


def start_sqlite_queue_writer():
global LLM_WRITE_QUEUE, _writer_thread
if LLM_WRITE_QUEUE is not None:
return # Already running
LLM_WRITE_QUEUE = queue.Queue()
_writer_thread = threading.Thread(target=queue_sqlite_writer, daemon=True)
_writer_thread.start()
class SQLiteQueueManager:
def __init__(self):
self.write_queue: Optional[queue.Queue] = None
self.writer_thread: Optional[threading.Thread] = None

def __enter__(self):
"""Start the SQLite queue writer when entering the context."""
if self.write_queue is not None:
return self # Already running

self.write_queue = queue.Queue()
self.writer_thread = threading.Thread(
target=self._queue_sqlite_writer,
daemon=True
)
self.writer_thread.start()

# Set the global reference
global _ACTIVE_MANAGER
_ACTIVE_MANAGER = self
return self

def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop the SQLite queue writer when exiting the context."""
if self.write_queue is not None and self.writer_thread is not None:
self.wait_for_empty()
self.write_queue.put(None) # Signal thread to stop
self.writer_thread.join() # Wait for thread to finish
self.write_queue = None
self.writer_thread = None

# Clear the global reference
global _ACTIVE_MANAGER
_ACTIVE_MANAGER = None

def _queue_sqlite_writer(self):
"""The worker function that processes the queue."""
while True:
item = self.write_queue.get()
if item is None: # Stop signal
break
sqlite_writer(item)
self.write_queue.task_done()

def wait_for_empty(self, timeout: Optional[float] = None) -> bool:
"""Wait for the queue to be empty and all tasks to be processed."""
if self.write_queue is None:
return True

def stop_sqlite_queue_writer():
global LLM_WRITE_QUEUE, _writer_thread
if LLM_WRITE_QUEUE is not None and _writer_thread is not None:
LLM_WRITE_QUEUE.put(None) # Signal thread to stop
_writer_thread.join() # Wait for thread to finish
LLM_WRITE_QUEUE = None
_writer_thread = None
try:
self.write_queue.join()
start_time = time.monotonic()
while not self.write_queue.empty():
if timeout is not None:
elapsed = time.monotonic() - start_time
if elapsed >= timeout:
return False
time.sleep(0.1)
self.write_queue.join()
return True
except Exception as e:
logger.error(f"Error while waiting for queue to empty: {e}")
return False

@property
def queue(self) -> Optional[queue.Queue]:
"""Access the write queue."""
return self.write_queue

@property
def is_empty(self) -> bool:
"""Check if the queue is empty."""
return self.write_queue is None or self.write_queue.empty()

0 comments on commit 249bbde

Please sign in to comment.