-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
55 lines (44 loc) · 1.76 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import streamlit as st
import torch
import numpy as np
import tiktoken
from utilities.dataloader import text_to_token_ids, token_ids_to_text
from model.transformer import TransformerModel
from generate import generate
# Model hyper-parameters. These must match the architecture the checkpoint in
# "model.pth" was saved from; load_state_dict (strict by default) raises on a
# mismatch.
CONFIG = {
    "vocab_size": 50257,  # vocabulary size of the "gpt2" tiktoken encoding used below
    "ctx_len": 1024,      # maximum context length; also passed to generate()
    "emb_dim": 768,
    "n_heads": 12,
    "n_layers": 12,
    "drop_rate": 0,
    "qkv_bias": False,
}


@st.cache_resource
def _load_model():
    """Build the transformer and load the trained weights from "model.pth".

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` makes the (expensive) construction and checkpoint
    read happen once per server process instead of on every rerun.
    """
    net = TransformerModel(CONFIG)
    # map_location="cpu": a checkpoint saved from a GPU still loads on a
    # CPU-only host. Nothing in this script moves the model to another device.
    net.load_state_dict(torch.load("model.pth", map_location="cpu"))
    net.eval()  # inference mode (disables dropout etc.)
    return net


model = _load_model()
# Re-seed on every rerun, after the cached load, so the RNG state seen by
# sampling is the same whether or not the model was just constructed.
torch.manual_seed(123)
tokenizer = tiktoken.get_encoding("gpt2")
@st.cache_data  # Memoize on the arguments so an identical request skips the model.
def run_generation(input_text, max_length=50, temperature=1.0, top_k=10):
    """Run the module-level ``model`` on *input_text* and return decoded text.

    Args:
        input_text: Prompt string, encoded with the module-level tokenizer.
        max_length: Passed to ``generate`` as ``max_new_tokens`` (a token
            count, not a word count).
        temperature: Sampling temperature forwarded to ``generate``.
        top_k: Forwarded to ``generate`` as ``top_k``. Values ``<= 0`` are
            converted to ``None`` (intended to mean "no top-k filtering").

    Returns:
        The decoded output of ``generate`` with surrounding whitespace
        stripped (presumably prompt + continuation -- confirm in generate()).
    """
    # NOTE(review): assumes generate() treats top_k=None as "disabled"
    # (TODO confirm). The UI slider can produce 0, and a typical top-k step
    # (torch.topk(logits, 0)[..., -1]) fails on an empty candidate set.
    if top_k is not None and top_k <= 0:
        top_k = None
    encoded = text_to_token_ids(input_text, tokenizer)
    # Inference only: no autograd graph is needed. Harmless if generate()
    # already disables gradients internally.
    with torch.no_grad():
        out = generate(
            model=model,
            idx=encoded,
            max_new_tokens=max_length,
            context_size=CONFIG["ctx_len"],
            top_k=top_k,
            temperature=temperature,
        )
    return token_ids_to_text(out, tokenizer).strip()
# Page layout: title, prompt box, three sampling controls in columns, and a
# button that triggers generation.
st.title('Story Generator')
start_context = st.text_area(
    'Start Context',
    value="Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day,",
    height=150,
)
c1, c2, c3 = st.columns(3)
with c1:
    # run_generation forwards this as max_new_tokens, so it counts tokens, not words.
    num_words = st.slider('Max New Tokens:', min_value=10, max_value=200, step=5, value=50)
with c2:
    temperature = st.slider('Temperature:', min_value=0.0, max_value=2.0, step=0.1, value=1.0)
with c3:
    top_k = st.slider('Top K Sampling:', min_value=0, max_value=50, step=5, value=10)
if st.button('Generate'):
    if not start_context.strip():
        # Don't hand the model an empty/blank prompt; ask for input instead.
        st.warning('Please enter some start context first.')
    else:
        with st.spinner('Generating...'):
            output_text = run_generation(
                start_context, max_length=num_words, temperature=temperature, top_k=top_k
            )
        # Separate heading: st.write renders markdown, where a single "\n" does
        # not break the line, so heading and text used to run together.
        st.subheader('Generated Text:')
        st.write(output_text)