llama_3_inference_chat_completion_test.py
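"""Sample script: generate Llama 3 chat completions, optionally with a trained SAE
(sparse autoencoder) applied at one transformer layer for interpretability experiments.

Example invocation (the paths below are illustrative placeholders, not part of the repository):

    python llama_3_inference_chat_completion_test.py \
        --llama_model_dir /path/to/Meta-Llama-3-8B-Instruct \
        --sae_model_path /path/to/sae_checkpoint.pt
"""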
import argparse
import logging
import os
from pathlib import Path

import torch

from llama_3.datatypes import SystemMessage, UserMessage
from llama_3_inference import Llama3Inference
from sae import load_sae_model
from utils.cuda_utils import set_torch_seed_for_inference


def parse_arguments() -> argparse.Namespace:
    """Parse command-line arguments for the chat completion test."""
    parser = argparse.ArgumentParser(description="Sample Llama 3 chat completion with an optional SAE hook.")
    parser.add_argument("--llama_model_dir", type=Path, required=True, help="Llama 3 model directory.")
    parser.add_argument("--sae_model_path", type=Path, default=None, help="Optional trained SAE checkpoint.")
    return parser.parse_args()


def main() -> None:
    """Generate sample chat completions with Llama 3, optionally steered by an SAE."""
    # Set up logging
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Parse arguments and set up paths
    args = parse_arguments()
    args.llama_model_dir = args.llama_model_dir.resolve()
    llama_tokenizer_path = args.llama_model_dir / "tokenizer.model"
    llama_params_path = args.llama_model_dir / "params.json"
    llama_model_path = args.llama_model_dir / "consolidated.00.pth"
    if args.sae_model_path is not None:
        args.sae_model_path = args.sae_model_path.resolve()

    # Set up configuration
    max_new_tokens = 128
    temperature = 0.7
    top_p = 0.9
    seed = 42
    sae_layer_idx = None
    sae_h_bias = None
    sae_top_k = 64
    sae_normalization_eps = 1e-6
    sae_dtype = torch.float32
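    # NOTE: sae_layer_idx and sae_h_bias above are placeholders that must be edited in-source.
    # To steer generation with the SAE, set sae_layer_idx to the layer index the SAE was trained
    # on and sae_h_bias to a (latent_index, bias_value) pair; the bias is written into a zero
    # vector over the SAE latents further below. The concrete layer and latent choices depend on
    # your trained SAE and are not prescribed by this script.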
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logging.info("#### Starting sample Llama3 Chat Completion")
    logging.info("#### Arguments:")
    logging.info(f"# llama_model_dir={args.llama_model_dir}")
    logging.info(f"# sae_model_path={args.sae_model_path}")
    logging.info("#### Configuration:")
    logging.info(f"# max_new_tokens={max_new_tokens}")
    logging.info(f"# temperature={temperature}")
    logging.info(f"# top_p={top_p}")
    logging.info(f"# seed={seed}")
    logging.info(f"# sae_layer_idx={sae_layer_idx}")
    logging.info(f"# sae_h_bias={sae_h_bias}")
    logging.info(f"# sae_top_k={sae_top_k}")
    logging.info(f"# sae_normalization_eps={sae_normalization_eps}")
    logging.info(f"# sae_dtype={sae_dtype}")
    logging.info(f"# device={device}")

    # Set up CUDA and seed for inference
    set_torch_seed_for_inference(seed)

    # Load the SAE model if provided and set up the forward fn for the specified sae_layer_idx
    sae_layer_forward_fn = None
    if args.sae_model_path is not None:
        assert sae_layer_idx is not None, "sae_layer_idx must be set when an SAE model is provided"
        sae_model = load_sae_model(
            model_path=args.sae_model_path,
            sae_top_k=sae_top_k,
            sae_normalization_eps=sae_normalization_eps,
            device=device,
            dtype=sae_dtype,
        )
        sae_layer_forward_fn = {sae_layer_idx: sae_model.forward}

        if sae_h_bias is not None:
            logging.info("Setting SAE h_bias...")
            h_bias = torch.zeros(sae_model.n_latents)
            h_bias[sae_h_bias[0]] = sae_h_bias[1]
            h_bias = h_bias.to(sae_dtype).to(device)
            sae_model.set_latent_bias(h_bias)

    # Initialize the Llama3Inference generator
    llama_inference = Llama3Inference(
        tokenizer_path=llama_tokenizer_path,
        params_path=llama_params_path,
        model_path=llama_model_path,
        device=device,
        sae_layer_forward_fn=sae_layer_forward_fn,
    )

    # Prepare batch for chat completion
    logging.info("Generating sample chat completions...")
    system_message = SystemMessage(
        content="You are a pirate chatbot who always responds in pirate speak!",
    )
    user_message_1 = UserMessage(content="Who are you?")
    user_message_2 = UserMessage(content="What is your purpose?")
    user_message_3 = UserMessage(content="Where are you from?")
    user_message_4 = UserMessage(content="What is your favorite color?")
    message_sequences = [
        [system_message, user_message_1],
        [system_message, user_message_2],
        [system_message, user_message_3],
        [system_message, user_message_4],
    ]

    # Generate chat completions and print results iteratively
    for next_messages in llama_inference.generate_chat_completions(
        message_sequences=message_sequences,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    ):
        # Clear the console for a more 'commercial LLM web UI' feel
        os.system("clear")
        # Update each completion with the new message content and print
        for i, message in enumerate(next_messages):
            print(f"#### Chat Completion {i + 1}: ".ljust(80, "#"))
            print(message.content)
            print("#" * 80)

    logging.info("#### FIN!")


if __name__ == "__main__":
    main()