Merge generate.py and generate/full.py #432

Open · wants to merge 1 commit into base: main
8 changes: 4 additions & 4 deletions README.md
@@ -77,7 +77,7 @@ To generate text predictions, you need to download the model weights. **If you d
Run inference:

```bash
python generate.py --prompt "Hello, my name is"
python generate/full.py --prompt "Hello, my name is"
```

This will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
@@ -86,14 +86,14 @@ This will run the 7B model and require ~26 GB of GPU memory (A100 GPU).

### Run Lit-LLaMA on consumer devices

On GPUs with `bfloat16` support, the `generate.py` script will automatically convert the weights and consume about ~14 GB.
On GPUs with `bfloat16` support, the `generate/full.py` script will automatically convert the weights and consume about 14 GB.
For GPUs with less memory, or ones that don't support `bfloat16`, enable quantization (`--quantize llm.int8`):

```bash
python generate.py --quantize llm.int8 --prompt "Hello, my name is"
python generate/full.py --quantize llm.int8 --prompt "Hello, my name is"
```

See `python generate.py --help` for more options.
See `python generate/full.py --help` for more options.

You can also use GPTQ-style int4 quantization, but this needs conversions of the weights first:

170 changes: 0 additions & 170 deletions generate.py

This file was deleted.

2 changes: 1 addition & 1 deletion generate/adapter.py
@@ -11,7 +11,7 @@
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from generate import generate
from generate.generate_utils import generate
from lit_llama import Tokenizer
from lit_llama.adapter import LLaMA
from lit_llama.utils import lazy_load, llama_model_lookup, quantization
2 changes: 1 addition & 1 deletion generate/adapter_v2.py
@@ -11,7 +11,7 @@
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from generate import generate
from generate.generate_utils import generate
from lit_llama import Tokenizer
from lit_llama.adapter import LLaMA
from lit_llama.utils import lazy_load, llama_model_lookup, quantization
23 changes: 14 additions & 9 deletions generate/full.py
@@ -12,10 +12,9 @@
sys.path.append(str(wd))

from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import quantization
from lit_llama.utils import quantization, lazy_load, llama_model_lookup
from scripts.prepare_alpaca import generate_prompt
from generate import generate

from generate.generate_utils import generate

def main(
prompt: str = "Hello, my name is",
@@ -28,6 +27,7 @@ def main(
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
model_size: str = "7B",
quantize: Optional[str] = None,
instruction_tuning: bool = False,
) -> None:
"""Generates text samples based on a pre-trained LLaMA model and tokenizer.

@@ -44,6 +44,7 @@
quantize: Whether to quantize the model and using which method:
``"llm.int8"``: LLM.int8() mode,
``"gptq.int4"``: GPTQ 4-bit mode.
instruction_tuning: Whether to wrap the prompt in the instruction-tuning format (as produced by `generate_prompt`).
"""
if not checkpoint_path:
checkpoint_path = Path(f"checkpoints/lit-llama/{model_size}/lit-llama.pth")
@@ -56,19 +57,23 @@
print("Loading model ...", file=sys.stderr)
t0 = time.time()

with fabric.init_module(empty_init=True), quantization(mode=quantize):
model = LLaMA.from_name(model_size)
with lazy_load(checkpoint_path) as checkpoint:
name = llama_model_lookup(checkpoint)

with fabric.init_module(empty_init=True), quantization(mode=quantize):
model = LLaMA.from_name(name)

checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint)
model.load_state_dict(checkpoint)
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

model.eval()
model = fabric.setup(model)

tokenizer = Tokenizer(tokenizer_path)
sample = {"instruction": prompt, "input": input}
prompt = generate_prompt(sample)

if instruction_tuning:
sample = {"instruction": prompt, "input": input}
prompt = generate_prompt(sample)
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
prompt_length = encoded.size(0)

77 changes: 77 additions & 0 deletions generate/generate_utils.py
@@ -0,0 +1,77 @@
import torch
from typing import Optional
from lit_llama import LLaMA

@torch.no_grad()
def generate(
model: LLaMA,
idx: torch.Tensor,
max_new_tokens: int,
*,
max_seq_length: Optional[int] = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
eos_id: Optional[int] = None,
) -> torch.Tensor:
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

The implementation of this function is modified from A. Karpathy's nanoGPT.

Args:
model: The model to use.
idx: Tensor of shape (T) with indices of the prompt sequence.
max_new_tokens: The number of new tokens to generate.
max_seq_length: The maximum sequence length allowed.
temperature: Scales the predicted logits by 1 / temperature
top_k: If specified, only sample among the tokens with the k highest probabilities
eos_id: If specified, stop generating any more tokens once the <eos> token is triggered
"""
# compute the final sequence length after generation
T = idx.size(0)
T_new = T + max_new_tokens
if max_seq_length is None:
max_seq_length = min(T_new, model.config.block_size)

device, dtype = idx.device, idx.dtype
# create an empty tensor of the expected final shape and fill in the current tokens
empty = torch.empty(T_new, dtype=dtype, device=device)
empty[:T] = idx
idx = empty
input_pos = torch.arange(0, T, device=device)

if idx.device.type == "xla":
import torch_xla.core.xla_model as xm

xm.mark_step()

# generate max_new_tokens tokens
for _ in range(max_new_tokens):
x = idx.index_select(0, input_pos).view(1, -1)

# forward
logits = model(x, max_seq_length, input_pos)
logits = logits[0, -1] / temperature

# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

probs = torch.nn.functional.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

# advance
input_pos = input_pos[-1:] + 1

if idx.device.type == "xla":
xm.mark_step()

# concatenate the new generation
idx = idx.index_copy(0, input_pos, idx_next)

# if <eos> token is triggered, return the output (stop generation)
if idx_next == eos_id:
return idx[:input_pos] # include the EOS token

return idx
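
For reviewers, below is a minimal usage sketch of the relocated `generate` helper. It mirrors the flow in `generate/full.py` (encode a prompt, sample a continuation); the checkpoint layout, the `7B` model size, and the `tokenizer.decode` / `tokenizer.eos_id` calls are assumptions for illustration, not part of this diff.

```python
# Sketch only: load the model and tokenizer roughly as generate/full.py does,
# then sample a continuation with generate.generate_utils.generate.
# Paths, model size, and decode/eos_id are assumed, not part of this PR.
from pathlib import Path

import torch

from generate.generate_utils import generate
from lit_llama import LLaMA, Tokenizer

checkpoint_path = Path("checkpoints/lit-llama/7B/lit-llama.pth")  # assumed layout
tokenizer_path = Path("checkpoints/lit-llama/tokenizer.model")    # assumed layout

tokenizer = Tokenizer(tokenizer_path)
model = LLaMA.from_name("7B")
model.load_state_dict(torch.load(checkpoint_path))
model.eval()

# Encode the prompt and generate up to 50 new tokens, stopping early at <eos>.
encoded = tokenizer.encode("Hello, my name is", bos=True, eos=False)
output = generate(
    model,
    encoded,
    max_new_tokens=50,
    temperature=0.8,
    top_k=200,
    eos_id=tokenizer.eos_id,  # assumed attribute, mirroring the other generate/* scripts
)
print(tokenizer.decode(output))  # assumed Tokenizer helper
```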
2 changes: 1 addition & 1 deletion generate/lora.py
Original file line number Diff line number Diff line change
@@ -11,10 +11,10 @@
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from generate import generate
from lit_llama import Tokenizer, LLaMA
from lit_llama.lora import lora
from lit_llama.utils import lazy_load, llama_model_lookup
from generate.generate_utils import generate
from scripts.prepare_alpaca import generate_prompt

lora_r = 8
6 changes: 3 additions & 3 deletions tests/test_generate.py
@@ -16,9 +16,9 @@
def load_generate_script():
sys.path.append(str(wd))

import generate as generate
from generate import full

return generate
return full


def test_generate():
@@ -111,7 +111,7 @@ def init_module(self, empty_init):


def test_cli():
cli_path = wd / "generate.py"
cli_path = wd / "generate/full.py"
output = subprocess.check_output([sys.executable, cli_path, "-h"])
output = str(output.decode())
assert "Generates text samples" in output