diff --git a/README.md b/README.md
index fd31e36d..1a30a859 100644
--- a/README.md
+++ b/README.md
@@ -77,7 +77,7 @@ To generate text predictions, you need to download the model weights. **If you d
 Run inference:
 
 ```bash
-python generate.py --prompt "Hello, my name is"
+python generate/full.py --prompt "Hello, my name is"
 ```
 
 This will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
@@ -86,14 +86,14 @@ This will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
 
 ### Run Lit-LLaMA on consumer devices
 
-On GPUs with `bfloat16` support, the `generate.py` script will automatically convert the weights and consume about ~14 GB.
+On GPUs with `bfloat16` support, the `full.py` script will automatically convert the weights and consume about ~14 GB.
 For GPUs with less memory, or ones that don't support `bfloat16`, enable quantization (`--quantize llm.int8`):
 
 ```bash
-python generate.py --quantize llm.int8 --prompt "Hello, my name is"
+python generate/full.py --quantize llm.int8 --prompt "Hello, my name is"
 ```
 
-See `python generate.py --help` for more options.
+See `python generate/full.py --help` for more options.
 
 You can also use GPTQ-style int4 quantization, but this needs conversions of the weights first:
 
diff --git a/generate.py b/generate.py
deleted file mode 100644
index 91a7a6e4..00000000
--- a/generate.py
+++ /dev/null
@@ -1,170 +0,0 @@
-import sys
-import time
-import warnings
-from pathlib import Path
-from typing import Optional
-
-import lightning as L
-import torch
-
-# support running without installing as a package
-wd = Path(__file__).parent.parent.resolve()
-sys.path.append(str(wd))
-
-from lit_llama import LLaMA, Tokenizer
-from lit_llama.utils import lazy_load, llama_model_lookup, quantization
-
-
-@torch.no_grad()
-def generate(
-    model: LLaMA,
-    idx: torch.Tensor,
-    max_new_tokens: int,
-    *,
-    max_seq_length: Optional[int] = None,
-    temperature: float = 1.0,
-    top_k: Optional[int] = None,
-    eos_id: Optional[int] = None,
-) -> torch.Tensor:
-    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
-
-    The implementation of this function is modified from A. Karpathy's nanoGPT.
-
-    Args:
-        model: The model to use.
-        idx: Tensor of shape (T) with indices of the prompt sequence.
-        max_new_tokens: The number of new tokens to generate.
-        max_seq_length: The maximum sequence length allowed.
-        temperature: Scales the predicted logits by 1 / temperature
-        top_k: If specified, only sample among the tokens with the k highest probabilities
-        eos_id: If specified, stop generating any more token once the <eos> token is triggered
-    """
-    # create an empty tensor of the expected final shape and fill in the current tokens
-    T = idx.size(0)
-    T_new = T + max_new_tokens
-    if max_seq_length is None:
-        max_seq_length = min(T_new, model.config.block_size)
-
-    device, dtype = idx.device, idx.dtype
-    # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty(T_new, dtype=dtype, device=device)
-    empty[:T] = idx
-    idx = empty
-    input_pos = torch.arange(0, T, device=device)
-
-    if idx.device.type == "xla":
-        import torch_xla.core.xla_model as xm
-
-        xm.mark_step()
-
-    # generate max_new_tokens tokens
-    for _ in range(max_new_tokens):
-        x = idx.index_select(0, input_pos).view(1, -1)
-
-        # forward
-        logits = model(x, max_seq_length, input_pos)
-        logits = logits[0, -1] / temperature
-
-        # optionally crop the logits to only the top k options
-        if top_k is not None:
-            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-            logits = torch.where(logits < v[[-1]], -float("Inf"), logits)
-
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)
-
-        # advance
-        input_pos = input_pos[-1:] + 1
-
-        if idx.device.type == "xla":
-            xm.mark_step()
-
-        # concatenate the new generation
-        idx = idx.index_copy(0, input_pos, idx_next)
-
-        # if <eos> token is triggered, return the output (stop generation)
-        if idx_next == eos_id:
-            return idx[:input_pos]  # include the EOS token
-
-    return idx
-
-
-def main(
-    prompt: str = "Hello, my name is",
-    *,
-    num_samples: int = 1,
-    max_new_tokens: int = 50,
-    top_k: int = 200,
-    temperature: float = 0.8,
-    checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
-    tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
-    quantize: Optional[str] = None,
-) -> None:
-    """Generates text samples based on a pre-trained LLaMA model and tokenizer.
-
-    Args:
-        prompt: The prompt string to use for generating the samples.
-        num_samples: The number of text samples to generate.
-        max_new_tokens: The number of generation steps to take.
-        top_k: The number of top most probable tokens to consider in the sampling process.
-        temperature: A value controlling the randomness of the sampling process. Higher values result in more random
-            samples.
-        checkpoint_path: The checkpoint path to load.
-        tokenizer_path: The tokenizer path to load.
-        quantize: Whether to quantize the model and using which method:
-            ``"llm.int8"``: LLM.int8() mode,
-            ``"gptq.int4"``: GPTQ 4-bit mode.
- """ - assert checkpoint_path.is_file(), checkpoint_path - assert tokenizer_path.is_file(), tokenizer_path - - precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true" - fabric = L.Fabric(devices=1, precision=precision) - - print("Loading model ...", file=sys.stderr) - t0 = time.time() - with lazy_load(checkpoint_path) as checkpoint: - name = llama_model_lookup(checkpoint) - - with fabric.init_module(empty_init=True), quantization(mode=quantize): - model = LLaMA.from_name(name) - - model.load_state_dict(checkpoint) - print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) - - model.eval() - model = fabric.setup(model) - - tokenizer = Tokenizer(tokenizer_path) - encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device) - prompt_length = encoded.size(0) - - L.seed_everything(1234) - for i in range(num_samples): - t0 = time.perf_counter() - y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k) - t = time.perf_counter() - t0 - - model.reset_cache() - print(tokenizer.decode(y)) - tokens_generated = y.size(0) - prompt_length - print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr) - if fabric.device.type == "cuda": - print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr) - - -if __name__ == "__main__": - from jsonargparse import CLI - - torch.set_float32_matmul_precision("high") - warnings.filterwarnings( - # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31 - "ignore", - message="ComplexHalf support is experimental and many operators don't support it yet" - ) - warnings.filterwarnings( - # Triggered in bitsandbytes/autograd/_functions.py:298 - "ignore", - message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization", - ) - CLI(main) diff --git a/generate/adapter.py b/generate/adapter.py index 1fe8af4d..014486ad 100644 --- a/generate/adapter.py +++ b/generate/adapter.py @@ -11,7 +11,7 @@ wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) -from generate import generate +from generate.generate_utils import generate from lit_llama import Tokenizer from lit_llama.adapter import LLaMA from lit_llama.utils import lazy_load, llama_model_lookup, quantization diff --git a/generate/adapter_v2.py b/generate/adapter_v2.py index d32db7c0..5449ebe5 100644 --- a/generate/adapter_v2.py +++ b/generate/adapter_v2.py @@ -11,7 +11,7 @@ wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) -from generate import generate +from generate.generate_utils import generate from lit_llama import Tokenizer from lit_llama.adapter import LLaMA from lit_llama.utils import lazy_load, llama_model_lookup, quantization diff --git a/generate/full.py b/generate/full.py index 443a75e3..5340ab65 100644 --- a/generate/full.py +++ b/generate/full.py @@ -12,10 +12,9 @@ sys.path.append(str(wd)) from lit_llama import LLaMA, Tokenizer -from lit_llama.utils import quantization +from lit_llama.utils import quantization, lazy_load, llama_model_lookup from scripts.prepare_alpaca import generate_prompt -from generate import generate - +from generate.generate_utils import generate def main( prompt: str = "Hello, my name is", @@ -28,6 +27,7 @@ def main( tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), model_size: str = "7B", quantize: Optional[str] = None, + instruction_tuning: Optional[bool] = False ) -> None: """Generates text samples based on 
a pre-trained LLaMA model and tokenizer. @@ -44,6 +44,7 @@ def main( quantize: Whether to quantize the model and using which method: ``"llm.int8"``: LLM.int8() mode, ``"gptq.int4"``: GPTQ 4-bit mode. + instruction_tuning: Whether to regenerate sample in instruction turning format. """ if not checkpoint_path: checkpoint_path = Path(f"checkpoints/lit-llama/{model_size}/lit-llama.pth") @@ -56,19 +57,23 @@ def main( print("Loading model ...", file=sys.stderr) t0 = time.time() - with fabric.init_module(empty_init=True), quantization(mode=quantize): - model = LLaMA.from_name(model_size) + with lazy_load(checkpoint_path) as checkpoint: + name = llama_model_lookup(checkpoint) + + with fabric.init_module(empty_init=True), quantization(mode=quantize): + model = LLaMA.from_name(name) - checkpoint = torch.load(checkpoint_path) - model.load_state_dict(checkpoint) + model.load_state_dict(checkpoint) print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) model.eval() model = fabric.setup(model) tokenizer = Tokenizer(tokenizer_path) - sample = {"instruction": prompt, "input": input} - prompt = generate_prompt(sample) + + if instruction_tuning: + sample = {"instruction": prompt, "input": input} + prompt = generate_prompt(sample) encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device) prompt_length = encoded.size(0) diff --git a/generate/generate_utils.py b/generate/generate_utils.py new file mode 100644 index 00000000..1c32831e --- /dev/null +++ b/generate/generate_utils.py @@ -0,0 +1,77 @@ +import lightning as L +import torch +from typing import Optional +from lit_llama import LLaMA + +@torch.no_grad() +def generate( + model: LLaMA, + idx: torch.Tensor, + max_new_tokens: int, + *, + max_seq_length: Optional[int] = None, + temperature: float = 1.0, + top_k: Optional[int] = None, + eos_id: Optional[int] = None, +) -> torch.Tensor: + """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + + The implementation of this function is modified from A. Karpathy's nanoGPT. + + Args: + model: The model to use. + idx: Tensor of shape (T) with indices of the prompt sequence. + max_new_tokens: The number of new tokens to generate. + max_seq_length: The maximum sequence length allowed. 
+        temperature: Scales the predicted logits by 1 / temperature
+        top_k: If specified, only sample among the tokens with the k highest probabilities
+        eos_id: If specified, stop generating any more token once the <eos> token is triggered
+    """
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    T = idx.size(0)
+    T_new = T + max_new_tokens
+    if max_seq_length is None:
+        max_seq_length = min(T_new, model.config.block_size)
+
+    device, dtype = idx.device, idx.dtype
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    empty = torch.empty(T_new, dtype=dtype, device=device)
+    empty[:T] = idx
+    idx = empty
+    input_pos = torch.arange(0, T, device=device)
+
+    if idx.device.type == "xla":
+        import torch_xla.core.xla_model as xm
+
+        xm.mark_step()
+
+    # generate max_new_tokens tokens
+    for _ in range(max_new_tokens):
+        x = idx.index_select(0, input_pos).view(1, -1)
+
+        # forward
+        logits = model(x, max_seq_length, input_pos)
+        logits = logits[0, -1] / temperature
+
+        # optionally crop the logits to only the top k options
+        if top_k is not None:
+            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+            logits = torch.where(logits < v[[-1]], -float("Inf"), logits)
+
+        probs = torch.nn.functional.softmax(logits, dim=-1)
+        idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)
+
+        # advance
+        input_pos = input_pos[-1:] + 1
+
+        if idx.device.type == "xla":
+            xm.mark_step()
+
+        # concatenate the new generation
+        idx = idx.index_copy(0, input_pos, idx_next)
+
+        # if <eos> token is triggered, return the output (stop generation)
+        if idx_next == eos_id:
+            return idx[:input_pos]  # include the EOS token
+
+    return idx
diff --git a/generate/lora.py b/generate/lora.py
index 38a3cf63..eb595b44 100644
--- a/generate/lora.py
+++ b/generate/lora.py
@@ -11,10 +11,10 @@
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from generate import generate
 from lit_llama import Tokenizer, LLaMA
 from lit_llama.lora import lora
 from lit_llama.utils import lazy_load, llama_model_lookup
+from generate.generate_utils import generate
 from scripts.prepare_alpaca import generate_prompt
 
 lora_r = 8
diff --git a/tests/test_generate.py b/tests/test_generate.py
index ecbd6afd..68ae18ab 100644
--- a/tests/test_generate.py
+++ b/tests/test_generate.py
@@ -16,9 +16,9 @@
 def load_generate_script():
     sys.path.append(str(wd))
 
-    import generate as generate
+    from generate import full
 
-    return generate
+    return full
 
 
 def test_generate():
@@ -111,7 +111,7 @@ def init_module(self, empty_init):
 
 
 def test_cli():
-    cli_path = wd / "generate.py"
+    cli_path = wd / "generate/full.py"
     output = subprocess.check_output([sys.executable, cli_path, "-h"])
     output = str(output.decode())
     assert "Generates text samples" in output
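
Below is a minimal usage sketch of the relocated sampling helper, following the same loading and sampling pattern that `generate/full.py` uses in this diff. The checkpoint and tokenizer paths are the repository's documented defaults and are assumptions here; adjust them to your local layout. Precision and quantization options are omitted for brevity.

```python
import lightning as L
from pathlib import Path

from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import lazy_load, llama_model_lookup
from generate.generate_utils import generate  # the sampling loop now lives here

# Assumed default paths from the README; adjust to your checkpoint layout.
checkpoint_path = Path("checkpoints/lit-llama/7B/lit-llama.pth")
tokenizer_path = Path("checkpoints/lit-llama/tokenizer.model")

fabric = L.Fabric(devices=1)

# Read the model name from the checkpoint and load the weights lazily,
# mirroring the loading sequence used by generate/full.py above.
with lazy_load(checkpoint_path) as checkpoint:
    name = llama_model_lookup(checkpoint)
    with fabric.init_module(empty_init=True):
        model = LLaMA.from_name(name)
    model.load_state_dict(checkpoint)

model.eval()
model = fabric.setup(model)

# Encode a prompt and sample new tokens with the shared generate() helper.
tokenizer = Tokenizer(tokenizer_path)
encoded = tokenizer.encode("Hello, my name is", bos=True, eos=False, device=fabric.device)
y = generate(model, encoded, max_new_tokens=50, temperature=0.8, top_k=200)
print(tokenizer.decode(y))
```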