forked from rasbt/LLMs-from-scratch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'upstream/main' into dev
- Loading branch information
Showing
6 changed files
with
262 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,249 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1545a16b-bc8d-4e49-b9a6-db6631e7483d", | ||
"metadata": {}, | ||
"source": [ | ||
"<table style=\"width:100%\">\n", | ||
"<tr>\n", | ||
"<td style=\"vertical-align:middle; text-align:left;\">\n", | ||
"<font size=\"2\">\n", | ||
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n", | ||
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n", | ||
"</font>\n", | ||
"</td>\n", | ||
"<td style=\"vertical-align:middle; text-align:left;\">\n", | ||
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n", | ||
"</td>\n", | ||
"</tr>\n", | ||
"</table>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f3f83194-82b9-4478-9550-5ad793467bd0", | ||
"metadata": {}, | ||
"source": [ | ||
"# Load And Use Finetuned Model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "466b564e-4fd5-4d76-a3a1-63f9f0993b7e", | ||
"metadata": {}, | ||
"source": [ | ||
"This notebook contains minimal code to load the finetuned model that was created and saved in chapter 6 via [ch06.ipynb](ch06.ipynb)." | ||
] | ||
}, | ||
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "fd80e5f5-0f79-4a6c-bf31-2026e7d30e52", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tiktoken version: 0.6.0\n", | |
"torch version: 2.2.2\n" | |
] | |
} | |
], | |
"source": [ | |
"from importlib.metadata import version\n", | |
"\n", | |
"# Libraries this notebook imports: the tokenizer and the deep learning library\n", | |
"pkgs = [\"tiktoken\", \"torch\"]\n", | |
"\n", | |
"# Report the installed version of each one\n", | |
"for p in pkgs:\n", | |
"    print(p, \"version:\", version(p))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "ed86d6b7-f32d-4601-b585-a2ea3dbf7201", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from pathlib import Path\n", | |
"\n", | |
"# Checkpoint file name, resolved relative to the current working directory\n", | |
"finetuned_model_path = Path(\"review_classifier.pth\")\n", | |
"\n", | |
"# Only warn (do not raise) when the checkpoint is missing\n", | |
"if not finetuned_model_path.exists():\n", | |
"    print(f\"Could not find '{finetuned_model_path}'.\")\n", | |
"    print(\"Run the `ch06.ipynb` notebook to finetune and save the finetuned model.\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "fb02584a-5e31-45d5-8377-794876907bc6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from previous_chapters import GPTModel\n", | |
"\n", | |
"\n", | |
"# Architecture settings per model variant, keyed by the label used in CHOOSE_MODEL\n", | |
"model_configs = {\n", | |
"    \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n", | |
"    \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n", | |
"    \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n", | |
"    \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n", | |
"}\n", | |
"\n", | |
"CHOOSE_MODEL = \"gpt2-small (124M)\"\n", | |
"\n", | |
"# Shared settings first, then the settings of the selected variant\n", | |
"BASE_CONFIG = {\n", | |
"    \"vocab_size\": 50257,     # Vocabulary size\n", | |
"    \"context_length\": 1024,  # Context length\n", | |
"    \"drop_rate\": 0.0,        # Dropout rate\n", | |
"    \"qkv_bias\": True,        # Query-key-value bias\n", | |
"    **model_configs[CHOOSE_MODEL],\n", | |
"}\n", | |
"\n", | |
"# Size tag of the selected variant, e.g. \"124M\" (not used elsewhere in this notebook)\n", | |
"model_size = CHOOSE_MODEL.rsplit(\" \", 1)[-1].lstrip(\"(\").rstrip(\")\")\n", | |
"\n", | |
"# Initialize base model\n", | |
"model = GPTModel(BASE_CONFIG)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "f1ccf2b7-176e-4cfd-af7a-53fb76010b94", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"\n", | |
"# Convert model to classifier as in section 6.5 in ch06.ipynb\n", | |
"num_classes = 2\n", | |
"model.out_head = torch.nn.Linear(in_features=BASE_CONFIG[\"emb_dim\"], out_features=num_classes)\n", | |
"\n", | |
"# Then load the finetuned weights from the checkpoint path defined earlier.\n", | |
"# weights_only=True keeps torch.load from unpickling arbitrary objects.\n", | |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | |
"model.load_state_dict(torch.load(finetuned_model_path, map_location=device, weights_only=True))\n", | |
"\n", | |
"# Move the model to the device that classify_review creates its inputs on;\n", | |
"# load_state_dict copies values into the existing parameters and does not move them,\n", | |
"# so without this step a CUDA machine would pair CPU weights with CUDA inputs\n", | |
"model.to(device)\n", | |
"model.eval();" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "a1fd174e-9555-46c5-8780-19b0aa4f26e5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import tiktoken\n", | |
"\n", | |
"# Tokenizer using tiktoken's \"gpt2\" encoding; passed to classify_review below\n", | |
"tokenizer = tiktoken.get_encoding(\"gpt2\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "2a4c0129-efe5-46e9-bb90-ba08d407c1a2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# This function was implemented in ch06.ipynb\n", | |
"def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256):\n", | |
"    \"\"\"Classify a single text as spam or not spam.\n", | |
"\n", | |
"    Args:\n", | |
"        text: String to classify.\n", | |
"        model: Classifier whose `pos_emb.weight` has one row per supported\n", | |
"            position and whose output is indexed as [batch, token, class].\n", | |
"        tokenizer: Object with an `encode(text)` method returning token IDs.\n", | |
"        device: Device the input tensor is created on; `model` must already\n", | |
"            be on this device.\n", | |
"        max_length: Number of tokens the input is truncated/padded to. Values\n", | |
"            above the model's context length are capped; `None` uses the\n", | |
"            full context length.\n", | |
"        pad_token_id: Token ID appended to inputs shorter than the target length.\n", | |
"\n", | |
"    Returns:\n", | |
"        \"spam\" if the predicted class index is 1, otherwise \"not spam\".\n", | |
"    \"\"\"\n", | |
"    model.eval()\n", | |
"\n", | |
"    # Prepare inputs to the model\n", | |
"    input_ids = tokenizer.encode(text)\n", | |
"\n", | |
"    # An embedding weight has shape (num_positions, emb_dim), so the supported\n", | |
"    # context length is dimension 0; dimension 1 is the embedding size\n", | |
"    supported_context_length = model.pos_emb.weight.shape[0]\n", | |
"\n", | |
"    # Never exceed the context length; fall back to it when max_length is not given\n", | |
"    if max_length is None:\n", | |
"        target_length = supported_context_length\n", | |
"    else:\n", | |
"        target_length = min(max_length, supported_context_length)\n", | |
"\n", | |
"    # Truncate sequences if they are too long\n", | |
"    input_ids = input_ids[:target_length]\n", | |
"\n", | |
"    # Pad shorter sequences up to the target length\n", | |
"    input_ids = input_ids + [pad_token_id] * (target_length - len(input_ids))\n", | |
"    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)  # add batch dimension\n", | |
"\n", | |
"    # Model inference\n", | |
"    with torch.no_grad():\n", | |
"        logits = model(input_tensor)[:, -1, :]  # Logits of the last output token\n", | |
"    predicted_label = torch.argmax(logits, dim=-1).item()\n", | |
"\n", | |
"    # Return the classified result\n", | |
"    return \"spam\" if predicted_label == 1 else \"not spam\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "1e26862c-10b5-4a0f-9dd6-b6ddbad2fc3f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"spam\n" | |
] | |
} | |
], | |
"source": [ | |
"# First example message\n", | |
"text_1 = \"You are a winner you have been specially selected to receive $1000 cash or a $2000 award.\"\n", | |
"\n", | |
"# Classify it with the finetuned model and print the predicted label\n", | |
"print(\n", | |
"    classify_review(\n", | |
"        text=text_1,\n", | |
"        model=model,\n", | |
"        tokenizer=tokenizer,\n", | |
"        device=device,\n", | |
"        max_length=120,\n", | |
"    )\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "78472e05-cb4e-4ec4-82e8-23777aa90cf8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"not spam\n" | |
] | |
} | |
], | |
"source": [ | |
"# Second example message\n", | |
"text_2 = \"Hey, just wanted to check if we're still on for dinner tonight? Let me know!\"\n", | |
"\n", | |
"# Classify it with the finetuned model and print the predicted label\n", | |
"print(\n", | |
"    classify_review(\n", | |
"        text=text_2,\n", | |
"        model=model,\n", | |
"        tokenizer=tokenizer,\n", | |
"        device=device,\n", | |
"        max_length=120,\n", | |
"    )\n", | |
")" | |
] | |
} | |
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters