Skip to content

Commit

Permalink
part 5.2 done
Browse files Browse the repository at this point in the history
  • Loading branch information
hbaghramyan committed Oct 25, 2024
1 parent cd1fd40 commit 33bae2a
Showing 1 changed file with 54 additions and 1 deletion.
55 changes: 54 additions & 1 deletion ch05/01_main-chapter-code/ch05.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import tiktoken
import os
import sys
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

sys.path.insert(0, os.getcwd())

Expand Down Expand Up @@ -203,7 +205,7 @@ def calc_loss_batch(input_batch, target_batch, model, device):
target_batch = target_batch.to(device)
logits = model(input_batch)
loss = torch.nn.functional.cross_entropy(
input=logits.flatten(0, 1), target=targets.flatten()
input=logits.flatten(0, 1), target=target_batch.flatten()
)
return loss

Expand Down Expand Up @@ -283,3 +285,54 @@ def evaluate_model(model, train_loader, val_loader, device, eval_iter):
val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
model.train()
return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, device, start_context):
model.eval()
context_size = model.pos_emb.weight.shape[0]
encoded = text_to_token_ids(start_context, tokenizer).to(device)
with torch.no_grad():
token_ids = generate_text_simple(
model=model, idx=encoded, max_new_tokens=50, context_size=context_size
)
decoded_text = token_ids_to_text(token_ids, tokenizer)
print(decoded_text.replace("\n", " "))
model.train()


torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)
num_epochs = 10
train_losses, val_losses, tokens_seen = train_model_simple(
model,
train_loader,
val_loader,
optimizer,
device,
num_epochs=num_epochs,
eval_freq=5,
eval_iter=5,
start_context="Every effort moves you",
tokenizer=tokenizer,
)


def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
fig, ax1 = plt.subplots(figsize=(5, 3))
ax1.plot(epochs_seen, train_losses, label="Training loss")
ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss")
ax1.legend(loc="upper right")
ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
ax2 = ax1.twiny()
ax2.plot(tokens_seen, train_losses, alpha=0)
ax2.set_xlabel("Tokens seen")
fig.tight_layout()
plt.show()


epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)

0 comments on commit 33bae2a

Please sign in to comment.