diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..582c0ad --- /dev/null +++ b/.gitignore @@ -0,0 +1,142 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version +.idea + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# MacOS DS_Store +.DS_Store + +# Pickle folder +.pkl_memoize_py3 + +# Folder where optimized models are stored +optimized_model + +# Config file for tests coverage +.coveragerc diff --git a/README.md b/README.md index b2e1547..3f1e09d 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,6 @@ -# **Open source implementation for LLaMA-based ChatGPT. 15x faster training process than ChatGPT (wip)** +# ChatLLaMA + +> šŸ“¢ Open source implementation for LLaMA-based ChatGPT runnable in a single GPU. 15x faster training process than `ChatGPT` Meta has recently released LLaMA, a collection of foundational large language models ranging from 7 to 65 billion parameters. LLaMA is creating a lot of excitement because it is smaller than GPT-3 but has better performance. For example, LLaMA's 13B architecture outperforms GPT-3 despite being 10 times smaller. This new collection of fundamental models opens the door to faster inference performance and chatGPT-like real-time assistants, while being cost-effective and running on a single GPU. @@ -12,14 +14,19 @@ The good news is that we introduce `ChatLLaMA`, the first open source implementa - ChatLLaMA has built-in support for DeepSpeed ZERO to speedup the fine-tuning process. - The library also supports all LLaMA model architectures (7B, 13B, 33B, 65B), so that you can fine-tune the model according to your preferences for training time and inference performance. -If you like the project, please show your support by [leaving a star ā­](https://github.com/nebuly-ai/nebullvm/stargazers). - Screen Shot 2023-02-26 at 10 56 13 PM Image from [OpenAIā€™s blog](https://openai.com/blog/chatgpt). +# Installation + +``` +pip install chatllama +``` + + # Get started with ChatLLaMA > :warning: Please note this code represents the algorithmic implementation for RLHF training process of LLaMA and does not contain the model weights. To access the model weights, you need to apply to Meta's [form](https://forms.gle/jk851eBVbX1m5TAv5). diff --git a/chatllama/__init__.py b/chatllama/__init__.py new file mode 100644 index 0000000..ffcc925 --- /dev/null +++ b/chatllama/__init__.py @@ -0,0 +1 @@ +__version__ = '0.0.3' diff --git a/chatllama/langchain_modules/__init__.py b/chatllama/langchain_modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chatllama/langchain_modules/prompt_templates.py b/chatllama/langchain_modules/prompt_templates.py new file mode 100644 index 0000000..eba679d --- /dev/null +++ b/chatllama/langchain_modules/prompt_templates.py @@ -0,0 +1,62 @@ +REWARD_TEMPLATE = dict( + template=( + "Lets pretend that you are a lawyer and you have to" + "evalaute the following completion task from a given" + "assigment with a score between 0 and 5 where 0 represents" + "a bad assignment completion and 5 a perfect completion.\n" + "You MUST evaluate: text quality, content quality and" + "coherence.\n" + "You MUST return only the number that represents your" + "judgment.\n" + "The assignement is:\n{user_input}\n" + "The completion is:\n{completion}\n" + ), + input_variables=["user_input", "completion"], +) + + +AI_CHATBOT_TEMPLATE = dict( + template=( + "Assistant is a large language model trained by Meta and Nebuly.ai\n" + "Assistant is designed to be able to assist with a wide range of " + "tasks, from answering simple questions to providing in-depth " + "explanations and discussions on a wide range of topics. As a " + "language model, Assistant is able to generate human-like text " + "based on the input it receives, allowing it to engage in " + "natural-sounding conversations and provide responses that are " + "coherent and relevant to the topic at hand.\n\n" + "Assistant is constantly learning and improving, and its capabilities " + "are constantly evolving. It is able to process and understand large " + "amounts of text, and can use this knowledge to provide accurate and " + "informative responses to a wide range of questions. Additionally, " + "Assistant is able to generate its own text based on the input it " + "receives, allowing it to engage in discussions and provide " + "explanations and descriptions on a wide range of topics.\n\n" + "Overall, Assistant is a powerful tool that can help with a wide " + "range of tasks and provide valuable insights and information on a " + "wide range of topics. Whether you need help with a specific " + "question or just want to have a conversation about a particular " + "topic, Assistant is here to assist.\n\n{history}\n\n" + "Human: {human_input}\n" + "Assistant:" + ), + input_variables=["history", "human_input"], +) + + +PERSON_CHATBOT_TEMPLATE = dict( + template=( + "You are a human chatting with a chatbot. The chatbot is a large " + "language model trained by Meta and Nebuly-ai\n" + "The chatbot is designed to be able to assist you with a wide range " + "of tasks, from answering simple questions to providing in-depth " + "explanations and discussions on a wide range of topics. You are a " + "human and you are testing the chatbot. Ask the chatbot questions and" + "see how it responds. You can also ask the chatbot to tell you a " + "story." + "\n\n{history}\n\n" + "Chatbot: {chatbot_input}\n" + "Human:" + ), + input_variables=["history", "chatbot_input"], +) diff --git a/chatllama/llama_model.py b/chatllama/llama_model.py new file mode 100644 index 0000000..448e266 --- /dev/null +++ b/chatllama/llama_model.py @@ -0,0 +1,224 @@ +import json +import os +from pathlib import Path +from typing import Tuple, List, Union + +import torch.distributed +import torch.nn as nn +from fairscale.nn.model_parallel.initialize import initialize_model_parallel +from fairscale.nn.model_parallel.layers import ( + ParallelEmbedding, + ColumnParallelLinear, +) +from llama import ModelArgs, Tokenizer +from llama.generation import sample_top_p +from llama.model import TransformerBlock, RMSNorm, precompute_freqs_cis + + +class HFLikeTokenizer: + def __init__(self, tokenizer: Tokenizer): + self.tokenizer = tokenizer + + def __call__(self, texts: Union[List[str], str], *args, **kwargs): + if isinstance(texts, str): + text = self.tokenizer.encode(texts, bos=True, eos=True) + tokens = torch.tensor(text).cuda().long() + else: + texts = [ + self.tokenizer.encode(text, bos=True, eos=True) + for text in texts + ] + max_len = max(len(text) for text in texts) + tokens = ( + torch.full((len(texts), max_len), self.tokenizer.pad_id) + .cuda() + .long() + ) + for i, text in enumerate(texts): + tokens[i, : len(text)] = torch.tensor(text).cuda().long() + output = { + "input_ids": tokens, + "attention_mask": (tokens != self.tokenizer.pad_id).long(), + } + return output + + def decode(self, tokens): + return self.tokenizer.decode(tokens) + + +class Transformer(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + + self.tok_embeddings = ParallelEmbedding( + params.vocab_size, params.dim, init_method=lambda x: x + ) + + self.layers = torch.nn.ModuleList() + for layer_id in range(params.n_layers): + self.layers.append(TransformerBlock(layer_id, params)) + + self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.output = ColumnParallelLinear( + params.dim, params.vocab_size, bias=False, init_method=lambda x: x + ) + + self.freqs_cis = precompute_freqs_cis( + self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 + ) + + def forward(self, tokens: torch.Tensor, attention_mask: torch.Tensor): + start_pos = int(torch.argmax(attention_mask.detach(), dim=-1).item()) + logits = self._forward(tokens, start_pos) + return logits + + def _forward(self, tokens: torch.Tensor, start_pos: int): + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens) + self.freqs_cis = self.freqs_cis.to(h.device) + freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] # noqa E203 + + mask = None + if seqlen > 1: + mask = torch.full( + (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device + ) + mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) + + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + output = self.output(h[:, -1, :]) # only compute last logits + return output.float() + + @torch.no_grad() + def generate( + self, + inputs: torch.Tensor, + attention_mask: torch.Tensor, + max_length: int, + temperature: float, + top_p: float = 1.0, + ): + prompt_size = inputs.shape[1] + total_len = min(self.params.max_seq_len, max_length + prompt_size) + start_pos = prompt_size # We assume left padding + prev_pos = 0 + generated_tokens = [] + for cur_pos in range(start_pos, total_len): + logits = self._forward(inputs[:, prev_pos:cur_pos], prev_pos) + if temperature > 0: + probs = torch.softmax(logits / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits, dim=-1) + next_token = next_token.reshape(-1) + generated_tokens.append(next_token) + prev_pos = cur_pos + return torch.stack(generated_tokens, dim=1) + + +def setup_model_parallel() -> Tuple[int, int]: + local_rank = int(os.environ.get("LOCAL_RANK", -1)) + world_size = int(os.environ.get("WORLD_SIZE", -1)) + + torch.distributed.init_process_group("nccl") + initialize_model_parallel(world_size) + torch.cuda.set_device(local_rank) + + # seed must be the same in all processes + torch.manual_seed(1) + return local_rank, world_size + + +def load_checkpoints( + ckpt_dir: str, local_rank: int, world_size: int +) -> Tuple[dict, dict]: + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + assert world_size == len(checkpoints), ( + f"Loading a checkpoint for MP={len(checkpoints)} but world " + f"size is {world_size}" + ) + ckpt_path = checkpoints[local_rank] + print("Loading") + checkpoint = torch.load(ckpt_path, map_location="cpu") + with open(Path(ckpt_dir) / "params.json", "r") as f: + params = json.loads(f.read()) + return checkpoint, params + + +def load_model( + ckpt_dir: str, + tokenizer_path: str, + local_rank: int, + world_size: int, + max_batch_size: int = 32, +) -> Tuple[Transformer, HFLikeTokenizer]: + checkpoint, params = load_checkpoints(ckpt_dir, local_rank, world_size) + model_args: ModelArgs = ModelArgs( + max_seq_len=1024, max_batch_size=max_batch_size, **params + ) + tokenizer = Tokenizer(model_path=tokenizer_path) + model_args.vocab_size = tokenizer.n_words + torch.set_default_tensor_type(torch.cuda.HalfTensor) + model = Transformer(model_args) + torch.set_default_tensor_type(torch.FloatTensor) + model.load_state_dict(checkpoint, strict=False) + tokenizer = HFLikeTokenizer(tokenizer) + return model, tokenizer + + +def generate( + model: Transformer, + tokenizer: Tokenizer, + prompts: List[str], + max_gen_len: int, + temperature: float = 0.8, + top_p: float = 0.95, +) -> List[str]: + bsz = len(prompts) + params = model.params + assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + prompt_tokens = [tokenizer.encode(x, bos=True, eos=False) for x in prompts] + + min_prompt_size = min([len(t) for t in prompt_tokens]) + max_prompt_size = max([len(t) for t in prompt_tokens]) + + total_len = min(params.max_seq_len, max_gen_len + max_prompt_size) + + tokens = torch.full((bsz, total_len), tokenizer.pad_id).cuda().long() + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t).long() + input_text_mask = tokens != tokenizer.pad_id + start_pos = min_prompt_size + prev_pos = 0 + for cur_pos in range(start_pos, total_len): + logits = model._forward(tokens[:, prev_pos:cur_pos], prev_pos) + if temperature > 0: + probs = torch.softmax(logits / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits, dim=-1) + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where( + input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + tokens[:, cur_pos] = next_token + prev_pos = cur_pos + + decoded = [] + for i, t in enumerate(tokens.tolist()): + # cut to max gen len + t = t[: len(prompt_tokens[i]) + max_gen_len] + # cut to eos tok if any + try: + t = t[: t.index(tokenizer.eos_id)] + except ValueError: + pass + decoded.append(tokenizer.decode(t)) + return decoded diff --git a/chatllama/rlhf/__init__.py b/chatllama/rlhf/__init__.py new file mode 100644 index 0000000..433034f --- /dev/null +++ b/chatllama/rlhf/__init__.py @@ -0,0 +1 @@ +"""RLHF implementation inspired to Lucidrains' implementation.""" diff --git a/chatllama/rlhf/actor.py b/chatllama/rlhf/actor.py new file mode 100644 index 0000000..3c8dbb9 --- /dev/null +++ b/chatllama/rlhf/actor.py @@ -0,0 +1,324 @@ +import json +import os + +import torch +from beartype import beartype +from beartype.typing import Optional, Tuple +from einops import rearrange +from torch.utils.data import Dataset, DataLoader +from config import ConfigActor +from utils import TrainingStats + +from chatllama.llama_model import load_model + + +class ActorModel(torch.nn.Module): + """Actor model that generates the augmented prompt from the initial + user_input. The aim is to train this model to generate better prompts. + + Attributes: + model: The model from LLaMA to be used + tokenizer: The LLaMA tokenizer + max_model_tokens (int): Maximum number of tokens that the model can + handle + config (ConfigActor): Configuration for the actor model + + Methods: + load: Load the model from a path + save: Save the model to a path + forward: Compute the action logits for a given sequence. + generate: Generate a sequence from a given prompt + """ + + def __init__(self, config: ConfigActor) -> None: + super().__init__() + # load the model + + self.max_model_tokens = 1024 + self.model, self.tokenizer = load_model( + ckpt_dir=config.model_folder, + tokenizer_path=config.tokenizer_folder, + local_rank=int(os.environ.get("LOCAL_RANK", -1)), + world_size=int(os.environ.get("WORLD_SIZE", -1)), + max_batch_size=config.batch_size, + ) + # save config + self.config = config + + def parameters(self, **kwargs): + """Return the parameters of the model + + Args: + **kwargs: + """ + return self.model.parameters() + + @beartype + def forward( + self, sequences: torch.Tensor, sequences_mask: torch.Tensor + ) -> torch.Tensor: + """Generate logits to have probability distribution over the vocabulary + of the actions + + Args: + sequences (torch.Tensor): Sequences of states and actions used to + compute token logits for the whole list of sequences + attention_mask (torch.Tensor): Mask for the sequences attention + + Returns: + logits (torch.Tensor): Logits for the actions taken + """ + model_output = self.model.forward( + sequences, attention_mask=sequences_mask + ) + if self.config.debug: + print("ActorModel.forward") + print("model_output_logits shape", model_output.logits.shape) + print("model_output logits", model_output.logits) + return model_output.logits + + @beartype + @torch.no_grad() + def generate( + self, states: torch.Tensor, state_mask: torch.Tensor + ) -> Tuple: + """Generate actions and sequences=[states, actions] from state + (i.e. input of the prompt generator model) + + Args: + state (torch.Tensor): the input of the user + state_mask (torch.Tensor): Mask for the state input (for padding) + + Returns: + actions (torch.Tensor): Actions generated from the state + sequences (torch.Tensor): Sequences generated from the + state as [states, actions] + """ + max_sequence = states.shape[1] + max_tokens = self.config.max_tokens + max_sequence + temperature = self.config.temperature + # What if the states + completion are longer than the max context of + # the model? + sequences = self.model.generate( + inputs=states, + attention_mask=state_mask, + max_length=max_tokens, + temperature=temperature, + ) + actions = sequences[:, states.shape[1] :] # noqa E203 + if self.config.debug: + print("ActorModel.generate") + print("state", states) + print("state shape", states.shape) + print("sequence shape", sequences.shape) + print("sequence", sequences) + print("actions shape", actions.shape) + print("actions", actions) + return actions, sequences + + @beartype + def load(self, path: Optional[str] = None) -> None: + """Load the model from the path + + Args: + path (str): Path to the model + """ + if path is None: + path = self.config.model_folder + "/" + self.config.model + ".pt" + if os.path.exists(self.config.model_folder) is False: + os.mkdir(self.config.model_folder) + print( + f"Impossible to load the model: {path}" + f"The path doesn't exist." + ) + return + # load the model + if os.path.exists(path) is False: + print( + f"Impossible to load the model: {path}" + f"The path doesn't exist." + ) + return + model_dict = torch.load(path) + self.model.load_state_dict(model_dict["model"]) + + @beartype + def save(self, path: Optional[str] = None) -> None: + """Save the model to the path + + Args: + path (Optional[str], optional): Path to store the model. + Defaults to None. + """ + if path is None: + path = self.config.model_folder + "/" + self.config.model + ".pt" + if os.path.exists(self.config.model_folder) is False: + os.mkdir(self.config.model_folder) + torch.save({"model": self.model.state_dict()}, path) + + +class ActorDataset(Dataset): + """Dataset for the pretraining of the actor model + read a json file with the following format: + [ + { + "user_input": "..." + "completion": "..." + } , + ... + ] + Where: + user_input: the input of the user + completion: the output of the user + """ + + def __init__(self, path: str, device: torch.device) -> None: + self.device = device + self.path = path + with open(path, "r") as f: + data = json.load(f) + self.data = [ + d["user_input"] + "\n\n###\n\n" + d["completion"] for d in data + ] + self.len = len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + def __len__( + self, + ): + return self.len + + +class ActorTrainer: + """Used to pre-train the actor model to generate better prompts. + + Args: + config (ConfigActor): Configuration for the actor model + + Attributes: + config (ConfigActor): Configuration for the actor model + model (ActorModel): Actor model + loss_function (torch.nn.CrossEntropyLoss): Loss function + optimizer (torch.optim.Adam): Optimizer + validation_flag (bool): Flag to indicate if the validation dataset + is provided + training_stats (TrainingStats): Training statistics + + Methods: + train: Train the actor model + """ + + def __init__(self, config: ConfigActor) -> None: + # load the model + self.config = config + self.model = ActorModel(config) + self.loss_function = torch.nn.CrossEntropyLoss() + self.optimizer = torch.optim.Adam( + self.model.parameters(), lr=config.lr + ) + self.validation_flag = False + self.training_stats = TrainingStats() + if not os.path.exists(config.model_folder): + os.mkdir(config.model_folder) + if config.validation_dataset_path is not None: + self.validation_flag = True + + def train( + self, + ) -> None: + print("Start Actor Model Pretraining") + # get config parameters + train_dataset_path = self.config.train_dataset_path + validation_dataset_path = self.config.validation_dataset_path + batch_size = self.config.batch_size + epochs = self.config.epochs + device = self.config.device + + # create dataloaders + train_dataset = ActorDataset(train_dataset_path, device=device) + train_dataloader = DataLoader(train_dataset, batch_size=batch_size) + if self.validation_flag: + eval_dataset = ActorDataset(validation_dataset_path, device=device) + validation_dataloader = DataLoader( + eval_dataset, batch_size=batch_size + ) + + # compute the number of iterations + n_iter = int(len(train_dataset) / batch_size) + + # traing loop + for epoch in range(epochs): + self.model.train() + for i, input_output in enumerate(train_dataloader): + input_output_tokenized = self.model.tokenizer( + input_output, + return_tensors="pt", + padding=True, + truncation=True, + ) + training_output = input_output_tokenized["input_ids"][:, 1:] + training_input = input_output_tokenized["input_ids"][:, :-1] + attention_mask = input_output_tokenized["attention_mask"][ + :, :-1 + ] + training_output = training_output.to(device) + training_input = training_input.to(device) + attention_mask = attention_mask.to(device) + + # forward pass + est_output = self.model.forward(training_input, attention_mask) + est_output = rearrange(est_output, "b s v -> (b s) v") + training_output = rearrange(training_output, "b s -> (b s)") + loss = self.loss_function(est_output, training_output) + self.training_stats.training_loss.append(loss.item()) + + # backward pass + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # print progress + if i % self.config.iteration_per_print == 0: + print( + f"Epoch: {epoch+1}/{epochs}, " + f"Iteration: {i+1}/{n_iter}, " + f"Training Loss: {loss}" + ) + if self.validation_flag: + self.model.eval() + for i, input_output in enumerate(validation_dataloader): + input_output_tokenized = self.model.tokenizer( + input_output, return_tensors="pt", padding=True + ) + validation_output = input_output_tokenized["input_ids"][ + :, 1: + ] + validation_input = input_output_tokenized["input_ids"][ + :, :-1 + ] + attention_mask = input_output_tokenized["attention_mask"][ + :, :-1 + ] + + # forward pass + est_output = self.model.forward( + validation_input, attention_mask + ) + validation_output = rearrange( + validation_output, "b s -> (b s)" + ) + est_output = rearrange(est_output, "b s v -> (b s) v") + loss = self.loss_function(est_output, validation_output) + self.training_stats.validation_loss.append(loss.item()) + + # print progress + if i % self.config.iteration_per_print == 0: + print( + f"Epoch: {epoch+1}/{epochs}, " + f"Iteration: {i+1}/{n_iter}, " + f"Validation Loss: {loss}" + ) + self.model.save() + print("Training Finished ") diff --git a/chatllama/rlhf/config.py b/chatllama/rlhf/config.py new file mode 100644 index 0000000..bc396a7 --- /dev/null +++ b/chatllama/rlhf/config.py @@ -0,0 +1,230 @@ +import yaml +import os +from dataclasses import dataclass + +import torch +from beartype import beartype +from beartype.typing import Optional + + +@dataclass +class ConfigReward: + """Config parameters for the reward model + + Attributes: + model (str): Model to be used for the reward model + model_folder (str): Path to the folder where model are stored (used + to load / store finetuned model) + device (torch.device): Device to be used for the reward model + model_head_hidden_size (int): Hidden size of the reward model head + debug (bool): enable prints for Debugging + train_dataset_path (Optional[str]): Path to the training dataset. + Default to None. To be specified only for the reward model trainig. + validation_dataset_path (Optional[str]): Path to the validation + dataset. Default to None. To be specified only for the reward + model trainig. + batch_size (Optional[int]): Batch size to train the reward model. + Default to None. To be specified only for the reward model + trainig. + epochs (Optional[int]): Number of epochs to train the reward model. + Default to None. To be specified only for the reward model + trainig. + iteration_per_print (Optional[int]): Number of iterations to print + the training loss. Default to None. To be specified only for the + reward model trainig. + lr (Optional[float]): Learning rate for the reward model. Default to + None. To be specified only for the reward model distillation. + llm_model (Optional[str]): Model to be used for the language model + (LLM). Default to None. + llm_max_tokens (Optional[int]): Max tokens for the LLM. Default to + None. + llm_temperature (Optional[float]): Temperature for the LLM. Default + to None. + """ + + model: str + model_folder: str + device: torch.device + model_head_hidden_size: int + debug: bool + train_dataset_path: Optional[str] = None + validation_dataset_path: Optional[str] = None + batch_size: Optional[int] = None + epochs: Optional[int] = None + iteration_per_print: Optional[int] = None + lr: Optional[float] = None + llm_model: Optional[str] = None + llm_max_tokens: Optional[int] = None + llm_temperature: Optional[float] = None + + +@dataclass +class ConfigActor: + """Config parameters for models + + Attributes: + model (str): Model to be used for the actor + model_folder (str): Path to the folder where model are stored (used + to load / store finetuned model) + max_tokens (int): Max tokens for the actor + temperature (float): Temperature for the actor + device (torch.device): Device to be used for the actor + lr (float): Learning rate for the actor + iteration_per_print (int): Number of iterations to print the + training loss + batch_size (int): Batch size to train the actor + epochs (int): Number of epochs to train the actor + debug (bool): Enable prints for debugging + train_dataset_path (str): Path to the training dataset + validation_dataset_path (Optional[str]): Path to the validation dataset + """ + + model: str + model_folder: str + tokenizer_folder: str + max_tokens: int + temperature: float + device: torch.device + lr: float + iteration_per_print: int + batch_size: int + epochs: int + debug: bool + train_dataset_path: str + validation_dataset_path: Optional[str] = None + + +@dataclass +class ConfigTrainer: + """Config parameters for the trainer, used to configure the reinforcement + learning training loop + + Attributes: + update_timesteps (int): Number of timesteps to update the actor + and critic. Every time update_timesteps timesteps are collected, + the training loop for the actor and critic is executed using the + memory buffer to learn the policy. + temperature (float): Temperature for the actor and critic + max_seq_len (int): Max sequence length for the actor and critic + num_examples (int): Number of examples to generate for the actor + and critic. For each iteration of timestep, num_examples are + sampled from the prompt dataset, processed and stored in the + memory buffer. + actor_lr (float): Learning rate for the actor when training with + reinforcement learning + critic_lr (float): Learning rate for the critic when training with + reinforcement learning + num_episodes (int): Number of episodes, each episodes consist of + a number of timesteps that are used to generate examples + stored in the memory buffer. + max_timesteps (int): Max timesteps for the actor and critic. + for each timestep a set of examples are sampled and used to + generate a completion and a reward. + batch_size (int): Batch size to train the actor and critic. + This batch is used to aggregate the memory from the memory buffer + for the actual training of the actor and critic models. + epochs (int): Number of epochs to train the actor and critic. + actor_eps_clip (float): Epsilon clip for the actor + critic_eps_clip (float): Epsilon clip for the critic + beta_s (float): Beta for the actor and critic + update_checkpoint (int): Number of timesteps to update the checkpoint + llm_model_id (str): Model id for the llm + llm_max_tokens (int): Max tokens for the llm + llm_temperature (float): Temperature for the llm + device (torch.device): Device to be used for the actor and critici + checkpoint_folder (str): Folder to store the checkpoints while training + debug (bool): Enable prints for debugging + """ + + update_timesteps: int + num_examples: int + actor_lr: float + critic_lr: float + num_episodes: int + max_timesteps: int + examples_path: str + batch_size: int + epochs: int + actor_eps_clip: float + critic_eps_clip: float + beta_s: float + update_checkpoint: int + llm_model_id: str + llm_max_tokens: int + llm_temperature: float + device: torch.device + checkpoint_folder: str + debug: bool + + +class Config: + """Store the config parameters for the whole pipeline + + Args: + trainer_dict (Optional[Dict]): Dictionary with the config parameters + for the trainer. Default to None. If None, the config.yaml is + used. + actor_dict (Optional[Dict]): Dictionary with the config parameters + for the actor. Default to None. If None, the config.yaml is + used. + critic_dict (Optional[Dict]): Dictionary with the config parameters + for the critic. Default to None. If None, the config.yaml is + used. + reward_dict (Optional[Dict]): Dictionary with the config parameters + for the reward. Default to None. If None, the config.yaml is + used. + device (Optional[torch.device]): Device to be used for the actor + and critic. Default to None. If None, the device available is + used. + debug (Optional[bool]): Enable prints for debugging. Default to False. + + Attributes: + trainer (ConfigTrainer): Config parameters for the trainer + actor (ConfigActor): Config parameters for the actor + critic (ConfigCritic): Config parameters for the critic + reward (ConfigReward): Config parameters for the reward + """ + + @beartype + def __init__( + self, + path: str, + device: Optional[torch.device] = None, + debug: Optional[bool] = False, + ) -> None: + + # if not specified use the device available + if device is None: + device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + print(f"Current device used:{str(device)}") + + if path is None or os.path.exists(path) is False: + raise ValueError("Path to the config.yaml is not valid") + + # Read the config from yaml + with open(path, "r") as c: + config = yaml.safe_load(c) + + trainer_dict = config["trainer_config"] + actor_dict = config["actor_config"] + critic_dict = config["critic_config"] + reward_dict = config["reward_config"] + + # Trainer Config + trainer_dict["device"] = device + trainer_dict["debug"] = debug + self.trainer = ConfigTrainer(**trainer_dict) + # Actor Config + actor_dict["device"] = device + actor_dict["debug"] = debug + self.actor = ConfigActor(**actor_dict) + # Critic Config + critic_dict["device"] = device + critic_dict["debug"] = debug + self.critic = ConfigReward(**critic_dict) + # Reward Config + reward_dict["device"] = device + reward_dict["debug"] = debug + self.reward = ConfigReward(**reward_dict) diff --git a/chatllama/rlhf/config.yaml b/chatllama/rlhf/config.yaml new file mode 100644 index 0000000..3678483 --- /dev/null +++ b/chatllama/rlhf/config.yaml @@ -0,0 +1,49 @@ +--- +trainer_config: + update_timesteps: 1 + num_examples: 2 + actor_lr: 0.00001 + critic_lr: 0.00001 + num_episodes: 10 + max_timesteps: 10 + examples_path: "dataset/sections_dataset.json" + batch_size: 1 + epochs: 5 + actor_eps_clip: 0.2 + critic_eps_clip: 0.2 + beta_s: 0.1 + update_checkpoint: 10 + llm_model_id: "text-davinci-003" + llm_max_tokens: 1024 + llm_temperature: 0.5 + checkpoint_folder: "./models/checkpoints" + +actor_config: + model: "llama-7B" + max_tokens: 1024 + temperature: 0.9 + train_dataset_path: "dataset/sections_dataset.json" + validation_dataset_path: null + batch_size: 16 + iteration_per_print: 10 + lr: 0.00001 + epochs: 1 + model_folder: "path-to-checkpoints" + +reward_config: + # model to be chosen are gp2-large, bart-base, longformer-base-4096 + model: "longformer-base-4096" + model_head_hidden_size: 2048 + model_folder: "./models" + train_dataset_path: "/home/pierpaolo/Documents/optimapi/dataset/sections_dataset.json" + validation_dataset_path: null + batch_size: 64 + epochs: 20 + iteration_per_print: 10 + lr: 0.0001 + +critic_config: + # model to be chosen are gp2-large, bart-base, longformer-base-4096 + model: "longformer-base-4096" + model_head_hidden_size: 2048 + model_folder: "./models" \ No newline at end of file diff --git a/chatllama/rlhf/reward.py b/chatllama/rlhf/reward.py new file mode 100644 index 0000000..c734a73 --- /dev/null +++ b/chatllama/rlhf/reward.py @@ -0,0 +1,426 @@ +import json +import os + +import torch +from beartype import beartype +from beartype.typing import Optional, Iterable +from einops.layers.torch import Rearrange +from langchain import OpenAI, LLMChain, PromptTemplate +from torch.utils.data import Dataset, DataLoader +from transformers import GPT2Tokenizer, GPT2Model, BartModel +from transformers import BartTokenizer, BartConfig, AutoModel, AutoTokenizer + +from chatllama.langchain_modules.prompt_templates import REWARD_TEMPLATE +from config import ConfigReward +from utils import TrainingStats + + +class RewardModel(torch.nn.Module): + """Model to be trained to predict the reward for RL. + or to be used as Critic in RL. + + Attributes: + model (torch.nn.Module): Model to be used for the reward model + tokenizer (torch.nn.Module): Tokenizer to be used for the reward model + head (torch.nn.Module): Head to be used for the reward model + config (ConfigReward): Config parameters for the reward model + max_model_tokens (int): Maximum sequence length for the reward model + + Methods: + forward: Forward pass of the model (used by the critic) + save: Save the model + load: Load the model + get_reward: Get the reward for a given input (used by the reward model) + """ + + def __init__(self, config: ConfigReward) -> None: + super().__init__() + # load the model -- add here other models + head_hidden_size = config.model_head_hidden_size + if config.model == "gpt2-large": + self.max_model_tokens = 1024 + self.model = GPT2Model.from_pretrained("gpt2-large") + self.tokenizer = GPT2Tokenizer.from_pretrained( + "gpt2-large", + padding_side="left", + truncation_side="left", + model_max_length=self.max_model_tokens, + ) + self.tokenizer.pad_token = self.tokenizer.eos_token + self.head = torch.nn.Sequential( + torch.nn.Linear(self.model.config.n_embd, head_hidden_size), + torch.nn.ReLU(), + torch.nn.Linear(head_hidden_size, 1), + Rearrange("... 1 -> ..."), + ) + elif config.model == "bart-base": + self.max_model_tokens = 1024 + bart_config = BartConfig.from_pretrained("facebook/bart-base") + bart_config.max_position_embeddings = 2048 + 1024 + self.model = BartModel(bart_config) + self.tokenizer = BartTokenizer.from_pretrained( + "facebook/bart-large", + padding_side="left", + truncation_side="left", + model_max_length=self.max_model_tokens, + ) + self.tokenizer.pad_token = self.tokenizer.eos_token + self.head = torch.nn.Sequential( + torch.nn.Linear(768, head_hidden_size), + torch.nn.ReLU(), + torch.nn.Linear(head_hidden_size, 1), + Rearrange("... 1 -> ..."), + ) + elif config.model == "longformer-base-4096": + self.max_model_tokens = 4096 + self.model = AutoModel.from_pretrained( + "allenai/longformer-base-4096" + ) + self.tokenizer = AutoTokenizer.from_pretrained( + "allenai/longformer-base-4096", + padding_side="left", + truncation_side="left", + model_max_length=self.max_model_tokens, + ) + self.tokenizer.eos_token = self.tokenizer.pad_token + self.head = torch.nn.Sequential( + torch.nn.Linear(768, head_hidden_size), + torch.nn.ReLU(), + torch.nn.Linear(head_hidden_size, 1), + Rearrange("... 1 -> ..."), + ) + else: + raise ValueError(f"model {config.model} not supported") + # store config + self.config = config + if os.path.exists(config.model_folder) is False: + os.mkdir(config.model_folder) + else: + self.load() + # freeze model parameters (only train the head) + for param in self.model.parameters(): + param.requires_grad = False + # move model to device + self.model.to(config.device) + self.head.to(config.device) + + @beartype + def parameters( + self, + ) -> Iterable[torch.nn.Parameter]: + """Return the parameters of the reward model""" + for p in self.model.parameters(): + yield p + for p in self.head.parameters(): + yield p + + @beartype + def forward( + self, output_sequence: torch.Tensor, output_sequence_mask: torch.Tensor + ) -> torch.Tensor: + """Generate the sequence of rewards for the given output sequence + what is the quality of the output sequence tokens? + + Args: + output_sequence (torch.Tensor): The sequence of tokens to be + evaluated + output_sequence_mask (torch.Tensor): Mask for the attention + + Returns: + torch.Tensor: Rewards for the given output sequence + """ + output = self.model( + output_sequence, attention_mask=output_sequence_mask + ) + # What if the output_sequence is longer than the max context of + # the model? + rewards = self.head(output.last_hidden_state) + if self.config.debug: + print("RewardModel.forward") + print("output_sequence.shape", output_sequence.shape) + print("output_sequence", output_sequence) + print("reward.shape", rewards.shape) + print("reward", rewards) + return rewards + + @beartype + def get_reward( + self, output_sequence: torch.Tensor, output_sequence_mask: torch.Tensor + ) -> torch.Tensor: + """Get the reward for the given output sequence + + Args: + output_sequence (torch.Tensor): The concatenation of initial input + and actor output as tokens + output_sequence_mask (torch.Tensor): Mask for the attention + """ + rewards = self.forward(output_sequence, output_sequence_mask) + return rewards[:, -1] + + @beartype + def load(self, path: Optional[str] = None) -> None: + """Load the model from the path + + Args: + path (str): path to the model + """ + if path is None: + path = self.config.model_folder + "/" + self.config.model + ".pt" + if os.path.exists(self.config.model_folder) is False: + os.makedirs(self.config.model_folder) + print( + f"Model folder does not exist. Creating it," + f"and returning without loading the model:\n{path}" + ) + return + # load the model and the tokenizer + if os.path.exists(path) is False: + print( + f"Impossible to load the model:\n{path}\n" + f"The path doesn't exist." + ) + return + model_dict = torch.load(path) + self.model.load_state_dict(model_dict["model"]) + self.head.load_state_dict(model_dict["head"]) + + @beartype + def save(self, path: Optional[str] = None) -> None: + """Save the model to the path + + Args: + path (Optional[str], optional): Path to store the model. + Defaults to None. + """ + if path is None: + path = self.config.model_folder + "/" + self.config.model + ".pt" + if os.path.exists(self.config.model_folder) is False: + os.makedirs(self.config.model_folder) + torch.save( + {"model": self.model.state_dict(), "head": self.head.state_dict()}, + path, + ) + + +# just to keep namings consistent +CriticModel = RewardModel + + +class RewardDataset(Dataset): + """Dataset class for the reward model + read a json file with the following format: + [ + { + "user_input": "...", + "completion": "...", + "score": ... + }, + ... + ] + Where: + user_input: the initial input of the user + completion: the completion generated by the model + score: the score given by the user to the completion (or by the LLM) + """ + + def __init__(self, path: str) -> None: + print(f"Loading dataset from {path}") + with open(path, "r") as f: + self.data = list(json.load(f)) + print(f"Loaded {len(self.data)} samples") + + def __getitem__(self, idx: int): + user_input = self.data[idx]["user_input"] + completion = self.data[idx]["completion"] + item = tuple([user_input, completion]) + return item + + def __len__( + self, + ): + return len(self.data) + + +class RewardTrainer: + """Reward class to train the reward model + + Args: + config (ConfigModel): Config parameters for the model + + Attributes: + model (RewardModel): Reward model + config (ConfigModel): Config parameters for the model + optimizer (torch.optim): Optimizer for the model + loss (torch.nn): Loss function for the model + + Methods: + train: Train the reward model + generate_user_input: Generate the user input for the LLM to evaluate a + couple, (user_input, completion) and assing a score + distillate: Parse the dataset and assign scores using LLMs + """ + + def __init__(self, config: ConfigReward) -> None: + self.model = RewardModel(config) + self.config = config + self.optimizer = torch.optim.Adam( + self.model.parameters(), lr=config.lr + ) + self.loss_function = torch.nn.MSELoss() + if not os.path.exists("./models"): + os.mkdir("./models") + self.training_stats = TrainingStats() + self.validation_flag = False + if config.validation_dataset_path is not None: + self.validation_flag = True + + openai_llm = OpenAI( + model_name=self.config.llm_model, + temperature=self.config.llm_temperature, + max_tokens=self.config.llm_max_tokens, + ) + prompt_template = PromptTemplate(**REWARD_TEMPLATE) + self.llm = LLMChain(llm=openai_llm, prompt=prompt_template) + + def distillate( + self, + ): + """Parse the dataset and assign scores using LLMs + then save back the dataset with the uploaded scores + """ + # load the dataset + with open(self.config.train_dataset_path, "r") as f: + train_data = json.load(f) + # for each element of the dataset, assing a score. + for i, data in enumerate(train_data): + if data.get("score", None) is None: + prompt_tokens = ( + data["user_input"] + + data["completion"] + + self.llm.prompt.template + ) + prompt_len = int(len(prompt_tokens.split(" ")) / 0.75) + # 80% of the max length as safety margin + if prompt_len > self.config.llm_max_tokens * 0.8: + print( + f"The prompt of the data {i} is too long\n" + f"tokens: {prompt_len}\n" + f"max_tokens: {self.config.llm_max_tokens * 0.8}" + ) + continue + score = self.llm.run( + user_input=data["user_input"], + completion=data["completion"], + ).strip() + # TODO extract from score the float value with a regex + score = score.split(" ")[0] + try: + score = float(score) + except Exception: + print( + f"The score returned by the LLM for the" + f"data, {i}, is not a float float:\n{score}" + ) + continue + data["score"] = score + # save the dataset back + with open(self.config.train_dataset_path, "w") as f: + json.dump(train_data, f) + + def train( + self, + ) -> None: + """Train the reward model""" + print("Start Training the Reward Model") + # get config parameters + train_dataset_path = self.config.train_dataset_path + validation_dataset_path = self.config.validation_dataset_path + batch_size = self.config.batch_size + epochs = self.config.epochs + device = self.config.device + + # create dataloaders + train_dataset = RewardDataset(train_dataset_path) + train_dataloader = DataLoader(train_dataset, batch_size=batch_size) + if self.validation_flag: + eval_dataset = RewardDataset(validation_dataset_path) + validation_dataloader = DataLoader( + eval_dataset, batch_size=batch_size + ) + iteration_per_print = self.config.iteration_per_print + + # compute the number of iterations + n_iter = int(len(train_dataset) / batch_size) + + # traing loop + for epoch in range(epochs): + self.model.train() + for i, inputs in enumerate(train_dataloader): + + input_text = inputs["user_input"] + inputs["completion"] + # tokenizer (placed here instead of dataset class) + input_tokens = self.model.tokenizer( + input_text, padding=True, truncation=True + ) + + score = None # TODO: load the score + + # TODO: check on the length of the input tokens if they are + # too many it can create problems + output = torch.tensor(score, dtype=torch.float32).to(device) + + # forward pass + est_output = self.model.get_reward( + input_tokens["input_ids"].to(device), + input_tokens["attention_mask"].to(device), + ) + + loss = self.loss_function(est_output, output) + self.training_stats.training_loss.append(loss.item()) + + # backward pass + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # print progress + if i % iteration_per_print == 0: + print( + f"Epoch: {epoch+1}/{epochs}, " + f"Iteration: {i+1}/{n_iter}, " + f"Training Loss: {loss.item()}" + ) + print( + "prediction", + est_output.cpu().detach().numpy(), + "target", + score.cpu().numpy(), + ) + if self.validation_flag: + self.model.eval() + for i, (text, score) in enumerate(validation_dataloader): + # forward pass + input_tokens = self.model.tokenizer( + text, return_tensors="pt", padding=True + ) + input_tokens = input_tokens.to(device) + # TODO: check on the length of the input tokens if they are + # too many it can create problems + output = torch.tensor(score, dtype=torch.float32).to( + device + ) + est_output = self.model.get_reward( + input_tokens["input_ids"], + input_tokens["attention_mask"], + ) + loss = self.loss_function(est_output, output) + self.training_stats.validation_loss.append(loss.item()) + + # print progress + if i % iteration_per_print == 0: + print( + f"Epoch: {epoch+1}/{epochs}, " + f"Iteration: {i+1}/{n_iter}, " + f"Validation Loss: {loss.item()}" + ) + self.model.save() diff --git a/chatllama/rlhf/test.py b/chatllama/rlhf/test.py new file mode 100644 index 0000000..55c8e83 --- /dev/null +++ b/chatllama/rlhf/test.py @@ -0,0 +1,47 @@ +import torch + +from actor import ActorTrainer +from config import Config +from trainer import RLTrainer +from reward import RewardTrainer + + +def test_actor_training(path=None, device=None, debug=False): + config = Config(path=path, device=device, debug=debug) + trainer = ActorTrainer(config.actor) + trainer.train() + trainer.training_stats.plot() + + +def test_reward_training(path=None, device=None, debug=False): + device = torch.device("cuda:0") + config = Config(path=path, device=device, debug=debug) + trainer = RewardTrainer(config.reward) + trainer.train() + trainer.training_stats.plot() + + +def test_rl_trainig(path=None, device=None, debug=False): + device = torch.device("cuda:0") + config = Config(path=path, device=device, debug=debug) + trainer = RLTrainer(config.trainer) + trainer.distillate() + trainer.train() + trainer.training_stats.plot() + + +if __name__ == "__main__": + reward_training = True + rl_training = False + actor_training = False + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + # place here the path to the config.yaml file + config_path = "/home/pierpaolo/Documents/optimapi/ptuning/config.yaml" + + if reward_training: + test_reward_training(path=config_path, device=device) + if rl_training: + test_rl_trainig(path=config_path, device=device) + if actor_training: + test_actor_training(path=config_path, device=device) diff --git a/chatllama/rlhf/trainer.py b/chatllama/rlhf/trainer.py new file mode 100644 index 0000000..594ab72 --- /dev/null +++ b/chatllama/rlhf/trainer.py @@ -0,0 +1,647 @@ +import json +import os +import random +from collections import deque, namedtuple + +import torch +from beartype import beartype +from beartype.typing import Deque, Tuple, List +from einops import rearrange +from torch.utils.data import Dataset, DataLoader + +from actor import ActorModel +from reward import RewardModel, CriticModel +from config import ConfigReward, ConfigActor, Config +from utils import TrainingStats, ConversationLog + + +class ActorCritic(torch.nn.Module): + """Actor Critic class stores both the actor and the critic models + and it generates values and action for given sequences during the training + of the actor. + + Attributes: + actor (ActorModel): Actor model + critic (CriticModel): Critic model + debug (bool): enable prints for Debugging + + Methods: + forward: given a sequence returns action logits and values (used + to evaluate the actor during training) + generate: given a sequence returns action, action logits, values + sequences and sequences masks (used to generate new sequences + during acting phase) + """ + + def __init__( + self, actor_config: ConfigActor, critic_config: ConfigReward + ) -> None: + super().__init__() + self.actor = ActorModel(actor_config) + self.critic = CriticModel(critic_config) + self.debug = actor_config.debug + + @beartype + def forward( + self, + sequences: torch.Tensor, + sequences_mask: torch.Tensor, + action_len: int, + ) -> Tuple: + """Given the whole sequences, use the actor forward to get the logits + for each token in the sequence and the critic forward to get the + values for each generation step. + + Args: + sequences (torch.Tensor): Sequences composed of [states, actions] + sequence_mask (torch.Tensor): Mask for the sequences + action_length (int): Length of the actions in the sequences + + Returns: + action_logits (torch.Tensor): Logits for the actions in the + sequences + values (torch.Tensor): Values for the actions in the sequences + """ + # use a single forward on the whole sequence + # to get pi(y | x) and ignore predicted output + actions_logits = self.actor.forward(sequences, sequences_mask) + values = self.critic.forward(sequences, sequences_mask) + + # return only logits and values for the actions taken + real_actions_logits = actions_logits[:, -action_len:, :] + real_values = values[:, -action_len:] + + if self.debug: + print("ActorCritic.forward") + print("action_len", action_len) + print("sequences.shape", sequences.shape) + print("sequences", sequences) + print("real_action_logits.shape", actions_logits.shape) + print("real_action_logits", actions_logits) + print("real_values.shape", values.shape) + print("real_values", values) + + return ( + real_actions_logits, + real_values, + ) + + @torch.no_grad() + @beartype + def generate( + self, states: torch.Tensor, state_mask: torch.Tensor + ) -> Tuple: + """Generate actions, actions_logits, values and sequences from states + + Args: + states (torch.Tensor): user inputs + state_mask (torch.Tensor): Mask for the states of the environment + + Returns: + actions (torch.Tensor): Actions generated from the states + actions_logits (torch.Tensor): Logits for the actions generated + from the states (i.e. pi(y | x)) + values (torch.Tensor): Values generated by the critic model + for the actions generated by the actor (i.e. V(x)) + sequences (torch.Tensor): Sequences generated from the states + as [states, actions] + """ + # generate action sequence + actions, sequence = self.actor.generate(states, state_mask) + sequences_mask = sequence != self.actor.tokenizer.pad_token_id + action_len = actions.shape[1] + + # generate actions_logits and values + actions_logits, values = self.forward( + sequence, sequences_mask, action_len + ) + if self.debug: + print("ActorCritic.generate") + print("actions shape", actions.shape) + print("actions", actions) + print("sequence shape", sequence.shape) + print("sequence", sequence) + print("actions_logits shape", actions_logits.shape) + print("actions_logits", actions_logits) + print("values shape", values.shape) + print("values", values) + + return actions, actions_logits, values, sequence, sequences_mask + + +# structure to store the data for each experience +Memory = namedtuple( + "Memory", + [ + "states", + "actions", + "sequences", + "values", + "rewards", + "actions_log_probs", + "sequences_mask", + ], +) + + +class ExperienceDataset(Dataset): + """Dataset to train the actor-critic models""" + + def __init__( + self, + memories: Deque[Memory], + device: torch.device, + ) -> None: + super().__init__() + self.data = list(memories) + self.device = device + + def __len__( + self, + ) -> int: + return len(self.data) + + def __getitem__(self, idx) -> Tuple: + # return the idx-th memory element as a tuple of tensors on the device + item = ( + self.data[idx].states.to(self.device), + self.data[idx].actions.to(self.device), + self.data[idx].sequences.to(self.device), + self.data[idx].values.to(self.device), + self.data[idx].rewards.to(self.device), + self.data[idx].actions_log_probs.to(self.device), + self.data[idx].sequences_mask.to(self.device), + ) + return item + + +class ExamplesSampler: + """Store the prompt to be sampled to generate the examples + read a json file with the following format: + [ + { + "user_input" : "", + } , + ... + ] + Where: + user_input: is the input of the user or directly the input of the user + with the memory preappended (i.e. user_input + memory) + """ + + def __init__( + self, + path: str, + ) -> None: + self.path = path + with open(path, "r") as f: + self.data = json.load(f) + + def sample(self, n: int) -> List: + """Sample n examples from the data + + Args: + n (int): Number of examples to sample + """ + return random.sample(self.data, n) + + +class RLTrainer: + """Train the actor-critic model using RL + + Attributes: + config (Config): Configuration of the trainer + debug (bool): Debug mode + actorcritic (ActorCritic): Actor-critic model + actor_optim (torch.optim): Optimizer for the actor + critic_optim (torch.optim): Optimizer for the critic + reward (RewardModel): Reward model + training_stats (TrainingStats): Class to store training stats + Methods: + train: the training loop that calls the learn function after generating + the experiences. + learn: Learn from a batch of experiences and update the actor and the + critic model. + load_checkpoint: Load the checkpoint of the actor-critic model + save_checkpoint: Save the checkpoint of the actor-critic model + generate_user_input: Generate the user input from the inputs + """ + + def __init__( + self, + config: Config, + ) -> None: + self.config = config + self.debug = config.trainer.debug + + # initialize agent-critic + self.actorcritic = ActorCritic(config.actor, config.critic) + self.actor_optim = torch.optim.Adam( + self.actorcritic.actor.parameters(), lr=config.trainer.actor_lr + ) + self.critic_optim = torch.optim.Adam( + self.actorcritic.critic.parameters(), lr=config.trainer.critic_lr + ) + + # initialize reward model + self.reward = RewardModel(config.reward) + + # initialize class to store training stats + self.training_stats = TrainingStats() + self.conversation_log = ConversationLog() + + # initialize examples sampler + self.example_sampler = ExamplesSampler(config.trainer.examples_path) + + # eps + self.eps = 1e-8 + + # make models directory + if not os.path.exists("./models"): + os.mkdir("./models") + + if not os.path.exists(self.config.trainer.checkpoint_folder): + os.mkdir(self.config.trainer.checkpoint_folder) + + def save_checkpoint( + self, + current_episode: int, + ) -> None: + print(f"Saving checkpoint for episode {current_episode+1}..") + file_name = "rltraining_" + str(current_episode) + ".pt" + checkpoint_folder = self.config.trainer.checkpoint_folder + if os.path.exists(checkpoint_folder) is False: + os.mkdir(checkpoint_folder) + path = checkpoint_folder + "/" + file_name + torch.save( + { + "episode": current_episode, + "actor_state_dict": self.actorcritic.actor.state_dict(), + "critic_state_dict": self.actorcritic.critic.state_dict(), + "actor_optim_state_dict": self.actor_optim.state_dict(), + "critic_optim_state_dict": self.critic_optim.state_dict(), + "training_stats": self.training_stats, + }, + path, + ) + + def load_checkpoint( + self, + ) -> int: + # get all the files name in the checkpoint folder and take the one + # with the highest epoch + checkpoint_folder = self.config.trainer.checkpoint_folder + if os.path.exists(checkpoint_folder) is False: + os.mkdir(checkpoint_folder) + print( + f"Checkpoint folder {checkpoint_folder} does not exist.\n" + f"No checkpoint will be loaded." + ) + return + files = os.listdir(checkpoint_folder) + episodes = [int(f.split("_")[1].split(".")[0]) for f in files] + if len(episodes) == 0: + return 0 + max_episode = max(episodes) + print(f"Loading checkpoint for episode {max_episode+1}..") + file_name = "rltraining_" + str(max_episode) + ".pt" + path = checkpoint_folder + "/" + file_name + checkpoint = torch.load(path) + self.actorcritic.actor.load_state_dict(checkpoint["actor_state_dict"]) + self.actorcritic.critic.load_state_dict( + checkpoint["critic_state_dict"] + ) + self.actor_optim.load_state_dict(checkpoint["actor_optim_state_dict"]) + self.critic_optim.load_state_dict( + checkpoint["critic_optim_state_dict"] + ) + self.trainign_stats = checkpoint["training_stats"] + self.actorcritic.actor.to(self.config.trainer.device) + self.actorcritic.critic.to(self.config.trainer.device) + return max_episode + 1 # return the next episode to train + + @beartype + def learn(self, memories: Deque[Memory]) -> None: + """Train the agent-critic model using RL: + - for each batch of episodes, compute action logits and values + - then compare action logits probs with memories one and values with + rewards to compute the PPO loss and update the actor-critic model + """ + print("Start to Learn...") + + # get parameters + epochs = self.config.trainer.epochs + actor_eps_clip = self.config.trainer.actor_eps_clip + critic_eps_clip = self.config.trainer.critic_eps_clip + beta_s = self.config.trainer.beta_s + batch_size = self.config.trainer.batch_size + device = self.config.trainer.device + + # create dataset from memories + dataloader = DataLoader( + ExperienceDataset(memories, device), batch_size=batch_size + ) + + # train agent-critic + self.actorcritic.train() + for epoch in range(epochs): + for i, ( + states, + old_actions, + sequences, + old_values, + rewards, + old_actions_log_probs, + sequences_mask, + ) in enumerate(dataloader): + + # print + print( + "Epoch", + epoch + 1, + "of", + epochs, + "Data", + i + 1, + "of", + int(len(dataloader) / batch_size), + ) + + if self.debug: + print("RLTrainer.learn()") + print("memory states shapes are: ") + print("states shape", states.shape) + print("old_actions shape", old_actions.shape) + print("sequences shape", sequences.shape) + print("old_values shape", old_values.shape) + print("rewards shape", rewards.shape) + print( + "old_actions_log_probs shape", + old_actions_log_probs.shape, + ) + # reshaping rewards to match [b, s] shape + rewards = rearrange(rewards, "b -> b 1") + + # get actions len + actions_len = old_actions.shape[-1] + + # get actor critic forward + actions_logits, values = self.actorcritic.forward( + sequences, sequences_mask, actions_len + ) + + # get action log prob + actions_prob = ( + torch.softmax(actions_logits, dim=-1).max(dim=-1).values + ) + actions_log_prob = torch.log(actions_prob + self.eps) + + # compute entropy + entropies = (actions_prob * actions_log_prob).sum(dim=-1) + + # compute KL divergence + kl_div_loss = ( + (actions_prob * (old_actions_log_probs - actions_log_prob)) + .sum(dim=-1) + .mean() + ) + + # compute PPO Loss -- Whan dimensions are different + # (especially the values and the probs are + # multiplied directly with the reward) + ratios = (actions_log_prob - old_actions_log_probs).exp() + advantages = rewards - old_values + # normalize advantages + advantages = (advantages - advantages.mean(dim=-1)) / ( + advantages.std() + self.eps + ) + surr1 = advantages * ratios + surr2 = ( + torch.clamp(ratios, 1 - actor_eps_clip, 1 + actor_eps_clip) + * advantages + ) + policy_loss = -torch.min(surr1, surr2) - beta_s * entropies + policy_loss = policy_loss.mean() + loss = policy_loss + kl_div_loss + # check if loss item is nan + if torch.isnan(loss): + raise ValueError("Loss is nan") + print("loss", loss.item()) + + if self.debug: + print("values", values) + print("old_values", old_values) + print("rewards", rewards) + print("ratios", ratios) + print("advantages", advantages) + print("entropies", entropies) + + # update actor with loss + self.actor_optim.zero_grad() + loss.backward() + self.actor_optim.step() + + torch.cuda.synchronize(device) + + # compute value loss + value_loss_clipped = old_values + (values - old_values).clamp( + -critic_eps_clip, critic_eps_clip + ) + value_loss1 = (value_loss_clipped - rewards) ** 2 + value_loss2 = (values - rewards) ** 2 + value_loss = torch.max(value_loss1, value_loss2).mean() + if torch.isnan(value_loss): + raise ValueError("Value loss is nan") + print("value_loss", value_loss.item()) + + # upate critic with loss + self.critic_optim.zero_grad() + value_loss.backward() + self.critic_optim.step() + + self.training_stats.training_loss.append( + loss.detach().cpu().item() + ) + self.training_stats.value_loss.append( + value_loss.detach().cpu().item() + ) + + self.actorcritic.eval() + print("End Learning") + + def train( + self, + ) -> None: + # initialize settings + num_episodes = self.config.trainer.num_episodes + max_timesteps = self.config.trainer.max_timesteps + num_examples = self.config.trainer.num_examples + update_timesteps = self.config.trainer.update_timesteps + batch_size = self.config.trainer.batch_size + update_checkpoint = self.config.trainer.update_checkpoint + device = self.config.trainer.device + + print("Start RL Training") + # check dimensions consistency + # at each time step num_examples memories are generated + number_of_memories_per_learn_iteration = ( + num_examples * update_timesteps + ) + # the number of memories must be a multiple of the batch size + assert ( + number_of_memories_per_learn_iteration % batch_size == 0 + ), "The number of memories must be a multiple of the batch size" + # the total number of timesteps is + total_number_of_timesteps = num_episodes * max_timesteps + # the update_timesteps must be a multiple + # of the total number of timesteps + assert total_number_of_timesteps % update_timesteps == 0, ( + "The number of timesteps (num_episodes*max_timesteps)" + "must be a multiple of the update_timesteps" + ) + + # initialize memories + memories = deque([]) + + # loop over episodes and timesteps + current_time = 0 + checkpoint_counter = 0 + current_episode = self.load_checkpoint() + current_learn_counter = 0 + + self.actorcritic.eval() + for eps in range(current_episode, num_episodes): + for timestep in range(max_timesteps): + + print( + f"Episode: {eps + 1} of {num_episodes}, " + f"Timestep: {timestep + 1} of {max_timesteps}", + ) + + # counter used to count timesteps into memory + current_time += 1 + + # sample num_examples examples from example dataset + inputs = self.example_sampler.sample(num_examples) + + # tokenize examples + tokenized_inputs = self.actorcritic.actor.tokenizer( + inputs, padding=True, return_tensors="pt" + ) + if self.debug: + print("RLTrainer.train()") + print("tokenized inputs", tokenized_inputs) + # states are [batch_size, seq_len_of_states] + states = tokenized_inputs["input_ids"].to(device) + states_mask = tokenized_inputs["attention_mask"].to(device) + + # generate prompts + # actions --> output produced by the actor head in response + # of the state(input) [batch_size, len_of_actions] + # actions_logits --> logits of the actions + # [batch_size, len_of_actions, vocab_size] + # values --> output produced by the critic for each action + # [batch_size, len_of_actions] + # sequence --> (state, actions) + # [batch_size, len_of_actions + seq_len_of_states] = + # [batch_size, seq_len] + ( + actions, + actions_logits, + values, + sequences, + sequences_mask, + ) = self.actorcritic.generate(states, states_mask) + + # from action logits to action log probs + action_prob = ( + torch.softmax(actions_logits, dim=-1).max(dim=-1).values + ) + actions_log_probs = torch.log(action_prob + self.eps) + + completions = [ + self.actorcritic.actor.tokenizer.decode(action) + for i, action in enumerate(actions) + ] + if self.debug: + print("RLTrainer.train()") + print("completions:") + for i, completion in enumerate(completions): + print(i, completion) + print("") + + # compute reward for the completion + # the reward must take into account the answer quality wrt to + # the initial request given + # and must be tokenized again + task_responses = [] + for input, completion in zip(inputs, completions): + task_response = input + "\n" + completion + task_responses.append(task_response) + if self.debug: + print("RLTrainer.train()") + print("task_responses:") + for i, task_response in enumerate(task_responses): + print(i, task_response) + print("") + tokenized_responses = self.reward.tokenizer( + task_responses, padding=True, return_tensors="pt" + ) + rewards = self.reward.get_reward( + tokenized_responses["input_ids"].to(device), + tokenized_responses["attention_mask"].to(device), + ) + + # store memories of the episode / timestep + for i in range(states.shape[0]): + memories.append( + Memory( + *map( + lambda x: x.detach().cpu(), + ( + states[i, :], + actions[i, :], + sequences[i, :], + values[i, :], + rewards[i], + actions_log_probs[i, :], + sequences_mask[i, :], + ), + ) + ) + ) + + # log the memories in the conversation log + for i in range(states.shape[0]): + self.conversation_log.add_conversation( + inputs[i], + completions[i], + rewards[i].detach().cpu(), + current_learn_counter, + ) + + # learn from memories + print( + f"Learning counter: {current_time} of {update_timesteps}" + ) + if (current_time % update_timesteps == 0) and ( + current_time != 0 + ): + checkpoint_counter += 1 + self.conversation_log.show(current_learn_counter) + self.learn(memories) + memories.clear() + current_time = 0 + current_learn_counter += 1 + + if (checkpoint_counter % update_checkpoint == 0) and ( + checkpoint_counter != 0 + ): + self.save_checkpoint(eps) + checkpoint_counter = 0 + + self.actorcritic.critic.save() + self.actorcritic.actor.save() + # print("Show conversations log") + # self.conversation_log.show() + print("End RL Training") diff --git a/chatllama/rlhf/utils.py b/chatllama/rlhf/utils.py new file mode 100644 index 0000000..68dbcf4 --- /dev/null +++ b/chatllama/rlhf/utils.py @@ -0,0 +1,143 @@ +import json +from beartype import beartype +from beartype.typing import Optional +from plotly import graph_objects as go + + +class TrainingStats: + """Training statistics + + Attributes: + training_loss (List): List of training losses + training_accuracy (List): List of training accuracies + value_loss (List): List of value losses + validation_loss (List): List of validation losses + validation_accuracy (List): List of validation accuracies + """ + + def __init__(self): + self.training_loss = [] + self.training_accuracy = [] + self.value_loss = [] + self.validation_loss = [] + self.validation_accuracy = [] + + def plot(self): + """Plot the training statistics using plotly""" + fig = go.Figure() + if len(self.training_loss) > 0: + fig.add_trace( + go.Scatter(y=self.training_loss, name="Training loss") + ) + if len(self.training_accuracy) > 0: + fig.add_trace( + go.Scatter(y=self.training_accuracy, name="Training accuracy") + ) + if len(self.value_loss) > 0: + fig.add_trace(go.Scatter(y=self.value_loss, name="Value loss")) + if len(self.validation_loss) > 0: + fig.add_trace( + go.Scatter(y=self.validation_loss, name="Validation loss") + ) + if len(self.validation_accuracy) > 0: + fig.add_trace( + go.Scatter( + y=self.validation_accuracy, name="Validation accuracy" + ) + ) + fig.update_layout( + showlegend=True, xaxis_type="log", xaxis_title="steps" + ) + fig.show() + + +class ConversationLog: + """Save the conversation: + (user input, model output, rewards and learn_counter) + during the RL training loop. Additionally, in order to be able to compare + the initial dataset of answers to the prompts, we store also the original + performance of the generation: + (generation_input, generation_output, generation_reward) + """ + + def __init__(self): + self.conversation = [] + + @beartype + def add_conversation( + self, + user_input: str, + model_output: str, + reward: float, + learn_counter: int, + previous_reward: float, + previous_completion: str, + ): + """Add a conversation to the log + + Args: + user_input (str): User input / initial prompt + model_output (str): Completion of the LLM model + reward (float): Reward of the reward model assigned to the output + learn_counter (int): Number of the learning iteration to + distinguish the conversations that happens at different + points of the training loop + previous_reward (float): Reward of the reward model assigned to + the output of original dataset + previous_completion (str): Completion of the LLM model of the + original dataset + """ + self.conversation.append( + { + "user_input": user_input, + "model_output": model_output, + "reward": reward, + "learn_counter": learn_counter, + "previous_reward": previous_reward, + "previous_completion": previous_completion, + } + ) + + def save(self, path: Optional[str] = "./conversation.json"): + with open(path, "r") as f: + conversation = json.load(f) + conversation.extend(self.conversation) + with open(path, "w") as f: + json.dump(conversation, f) + + def load(self, path: Optional[str] = "./conversation.json"): + with open(path, "r") as f: + self.conversation = json.load(f) + + def show(self, current_iteration: int = None): + """Show the conversation log + + Args: + current_iteration (int): Current iteration of the training loop, + if not None, print only the conversations that happened at + + """ + for i, c in enumerate(self.conversation): + if current_iteration is None: + print( + f"##########################################\n" + f"Conversation {i} at learn_counter " + f"{c['learn_counter']}\n" + f"##########################################\n" + f"## User Input:\n\n{c['user_input']}\n\n" + f"## Model Output:\n\n{c['model_output']}\n\n" + f"## Reward: {c['reward']}\n\n" + f"## Previous Reward: {c['previous_reward']}\n\n" + ) + else: + if current_iteration == c["learn_counter"]: + print( + f"##########################################\n" + f"Conversation {i} at learn_counter " + f"{c['learn_counter']}\n" + f"##########################################\n" + f"## User Input:\n\n{c['user_input']}\n\n" + f"## Model Output:\n\n{c['model_output']}\n\n" + f"## Reward: {c['reward']}\n\n" + f"## Previous Reward: {c['previous_reward']}\n\n" + ) diff --git a/generate_dataset.py b/generate_dataset.py new file mode 100644 index 0000000..4d36d2c --- /dev/null +++ b/generate_dataset.py @@ -0,0 +1,59 @@ +from langchain import OpenAI, LLMChain, PromptTemplate +from langchain.chains.conversation.memory import ( + ConversationalBufferWindowMemory, +) + +from chatllama.langchain_modules.prompt_templates import ( + PERSON_CHATBOT_TEMPLATE, + AI_CHATBOT_TEMPLATE, +) + + +CONVERSATION_LENGTH = 20 + + +def create_conversation(human_agent: LLMChain, bot_agent: LLMChain): + conversation = [] + chatbot_output = "" + for i in range(CONVERSATION_LENGTH): + # Human agent goes first + human_output = human_agent.run(chatbot_input=chatbot_output) + conversation.append(f"Human: {human_output}") + chatbot_output = bot_agent.run(human_input=human_output) + conversation.append(f"AI: {chatbot_output}") + return "\n".join(conversation) + + +def build_agents(): + # be aware that too long completions will not fit the sequence length + # of possible critic or reward models ... + llm = OpenAI(max_tokens=2048, temperature=0.7) + human_template = PromptTemplate(**PERSON_CHATBOT_TEMPLATE) + human_agent = LLMChain( + llm=llm, + prompt=human_template, + memory=ConversationalBufferWindowMemory(k=4), + ) + bot_template = PromptTemplate(**AI_CHATBOT_TEMPLATE) + bot_agent = LLMChain( + llm=llm, + prompt=bot_template, + memory=ConversationalBufferWindowMemory(k=4), + ) + return human_agent, bot_agent + + +def main(): + from argparse import ArgumentParser + + parser = ArgumentParser() + parser.add_argument("--num_conversations", type=int, default=1000) + parser.add_argument("--output_file", type=str, default="conversations.txt") + args = parser.parse_args() + conversations = [] + for conv in range(args.num_conversations): + human_agent, bot_agent = build_agents() + conversation = create_conversation(human_agent, bot_agent) + conversations.append(conversation) + with open(args.output_file, "w") as f: + f.write("\n\nNEW CONVERSATION\n\n".join(conversations)) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..dd2e65b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +pyllama +fairscale +langchain +beartype +einops +plotly diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..f48cbc4 --- /dev/null +++ b/setup.py @@ -0,0 +1,59 @@ +from setuptools import setup +import os + +here = os.path.dirname(os.path.realpath(__file__)) +HAS_CUDA = os.system("nvidia-smi > /dev/null 2>&1") == 0 + +VERSION = "0.0.4" +DESCRIPTION = "ChatLLaMA: Open and Efficient Foundation Language Models Runnable In A Single GPU" + +packages = [ + "chatllama", +] + + +def read_file(filename: str): + try: + lines = [] + with open(filename) as file: + lines = file.readlines() + lines = [line.rstrip() for line in lines if not line.startswith('#')] + return lines + except: + return [] + + +def package_files(ds): + paths = [] + for d in ds: + for (path, directories, filenames) in os.walk(d): + for filename in filenames: + if '__pycache__' not in str(filename): + paths.append(str(os.path.join(path, filename))[len('chatllama/'):]) + return paths + +extra_files = package_files(['chatllama/']) + + +setup( + name="chatllama", + version=VERSION, + author_email="", + description=DESCRIPTION, + long_description=open("README.md", "r", encoding="utf-8").read(), + long_description_content_type="text/markdown", + install_requires=read_file(f"{here}/requirements.txt"), + keywords=[ + "ChatLLaMA", "LLaMA" + ], + classifiers=[ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + ], + packages=packages, + package_data={"chatllama": extra_files}, + url="https://github.com/juncongmoo/chatllama" +)