Skip to content

Commit

Permalink
content: kge add ipnyb to vh repo
Browse files Browse the repository at this point in the history
  • Loading branch information
AruneshSingh committed Jan 11, 2024
1 parent 7fd36fb commit 3ffef1f
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells":[{"cell_type":"code","source":["# Install required packages.\n","import os\n","import torch\n","os.environ['TORCH'] = torch.__version__\n","print(torch.__version__)\n","\n","!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html\n","!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html\n","!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git"],"metadata":{"id":"p7nHdgFVTNFU"},"id":"p7nHdgFVTNFU","execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":3,"id":"4d7bac95-721a-49ef-b71b-973a0753764d","metadata":{"id":"4d7bac95-721a-49ef-b71b-973a0753764d","executionInfo":{"status":"ok","timestamp":1702389500296,"user_tz":-60,"elapsed":7,"user":{"displayName":"Richard Kiss","userId":"07987192335421876970"}}},"outputs":[],"source":["import torch\n","import json\n","import torch.optim as optim\n","from torch_geometric.nn import DistMult\n","from torch_geometric.datasets import FB15k_237"]},{"cell_type":"code","execution_count":4,"id":"9b40010b-c87d-4b5f-9e00-d2d3f43549e3","metadata":{"id":"9b40010b-c87d-4b5f-9e00-d2d3f43549e3","executionInfo":{"status":"ok","timestamp":1702389500297,"user_tz":-60,"elapsed":5,"user":{"displayName":"Richard Kiss","userId":"07987192335421876970"}}},"outputs":[],"source":["device = \"cuda\"\n","path = \"./data\""]},{"cell_type":"code","execution_count":null,"id":"2a8e1d9c-7808-49e4-a478-9f640d2318c8","metadata":{"id":"2a8e1d9c-7808-49e4-a478-9f640d2318c8"},"outputs":[],"source":["train_data = FB15k_237(path, split='train')[0].to(device)"]},{"cell_type":"code","execution_count":9,"id":"5246723c-2ecf-41a7-8710-02c7eee864c8","metadata":{"id":"5246723c-2ecf-41a7-8710-02c7eee864c8","executionInfo":{"status":"ok","timestamp":1702389555692,"user_tz":-60,"elapsed":272,"user":{"displayName":"Richard Kiss","userId":"07987192335421876970"}}},"outputs":[],"source":["# initialize model\n","model = DistMult(\n"," num_nodes=train_data.num_nodes,\n"," num_relations=train_data.num_edge_types,\n"," hidden_channels=64\n",").to(device)\n","\n","# initialize optimizer\n","opt = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-6)\n","\n","# create data loader on the training set\n","loader = model.loader(\n"," head_index=train_data.edge_index[0],\n"," rel_type=train_data.edge_type,\n"," tail_index=train_data.edge_index[1],\n"," batch_size=2000,\n"," shuffle=True,\n",")"]},{"cell_type":"code","execution_count":null,"id":"ad1a2c7a-e188-43f6-884d-f4ccee92f7d7","metadata":{"id":"ad1a2c7a-e188-43f6-884d-f4ccee92f7d7"},"outputs":[],"source":["EPOCHS = 100\n","model.train()\n","# usual torch training loop\n","for e in range(EPOCHS):\n"," l = []\n"," for batch in loader:\n"," opt.zero_grad()\n"," loss = model.loss(*batch)\n"," l.append(loss.item())\n"," loss.backward()\n"," opt.step()\n"," print(f\"Epoch {e} loss {sum(l) / len(l):.4f}\")"]},{"cell_type":"code","execution_count":null,"id":"8ff0c01b-b0ee-4c4c-a153-768704e22155","metadata":{"id":"8ff0c01b-b0ee-4c4c-a153-768704e22155"},"outputs":[],"source":["model.to(\"cpu\").eval()"]},{"cell_type":"code","execution_count":15,"id":"5d08643b-4664-459e-8c97-0e281117f80f","metadata":{"id":"5d08643b-4664-459e-8c97-0e281117f80f","executionInfo":{"status":"ok","timestamp":1702389623886,"user_tz":-60,"elapsed":4,"user":{"displayName":"Richard Kiss","userId":"07987192335421876970"}}},"outputs":[],"source":["france = 637 # entity France\n","rel = 15 # relation /location/location/contains\n","burgundy = 638 # entity Burgundy\n","riodj = 986 # entity Rio de Janeiro\n","bnc = 7485 # Bonnie and Clyde"]},{"cell_type":"code","execution_count":null,"id":"c875579b-270b-4f67-8e80-5d7473e94fb5","metadata":{"id":"c875579b-270b-4f67-8e80-5d7473e94fb5"},"outputs":[],"source":["# Define triples\n","head_entities = torch.tensor([france, france, france], dtype=torch.long)\n","relationships = torch.tensor([rel, rel, rel], dtype=torch.long)\n","tail_entities = torch.tensor([burgundy, riodj, bnc], dtype=torch.long)\n","\n","# Score triples using the model\n","scores = model(head_entities, relationships, tail_entities)\n","print(scores.tolist())"]},{"cell_type":"code","execution_count":17,"id":"f0928d3d-f0a8-4a3b-a9f2-62ad985b764d","metadata":{"id":"f0928d3d-f0a8-4a3b-a9f2-62ad985b764d","executionInfo":{"status":"ok","timestamp":1702389627339,"user_tz":-60,"elapsed":307,"user":{"displayName":"Richard Kiss","userId":"07987192335421876970"}}},"outputs":[],"source":["guy_ritchie = 5292 # entity Guy Ritchie\n","profession = 17 # relation /people/person/profession"]},{"cell_type":"code","execution_count":18,"id":"870fea6b-80ef-4bf3-bc75-6f26e1890b09","metadata":{"id":"870fea6b-80ef-4bf3-bc75-6f26e1890b09","executionInfo":{"status":"ok","timestamp":1702389627340,"user_tz":-60,"elapsed":4,"user":{"displayName":"Richard Kiss","userId":"07987192335421876970"}}},"outputs":[],"source":["# Accessing node and relation embeddings\n","node_embeddings = model.node_emb.weight\n","relation_embeddings = model.rel_emb.weight\n","\n","# Selecting specific entities and relations\n","guy_ritchie = node_embeddings[guy_ritchie]\n","profession = relation_embeddings[profession]"]},{"cell_type":"code","execution_count":19,"id":"b4debcc1-26bd-4e8f-809c-4a7f298bd1b7","metadata":{"id":"b4debcc1-26bd-4e8f-809c-4a7f298bd1b7","executionInfo":{"status":"ok","timestamp":1702389628103,"user_tz":-60,"elapsed":3,"user":{"displayName":"Richard Kiss","userId":"07987192335421876970"}}},"outputs":[],"source":["# Creating embedding for the query based on the chosen relation and entity\n","query = guy_ritchie * profession\n","\n","# Calculating scores using vector operations\n","scores = node_embeddings @ query\n","\n","# Find the index for the top 5 scores\n","sorted_indices = scores.argsort().tolist()[-5:][::-1]\n","# Get the score for the top 5 index\n","top_5_scores = scores[sorted_indices]"]},{"cell_type":"code","execution_count":null,"id":"655d7a74-6917-44f6-8313-d8066be8bc97","metadata":{"id":"655d7a74-6917-44f6-8313-d8066be8bc97"},"outputs":[],"source":["# List top 5 hits with scores\n","list(zip(sorted_indices, top_5_scores.tolist()))"]}],"metadata":{"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.12"},"colab":{"provenance":[],"gpuType":"T4"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":5}
2 changes: 1 addition & 1 deletion docs/use_cases/knowledge_graph_embedding.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ Moreover, these five professions are all closely related to the film industry, s

In sum, the model's performance in this scenario demonstrates its potential for **understanding concepts**, **interpreting context**, and **extracting semantic meaning**.

Here is the [complete code for this demo](https://drive.google.com/file/d/1G3tJ6Nn_6hKZ8HZGpx8OHpWwGqp_sQtF/view?usp=sharing).
Here is the [complete code for this demo](../assets/use_cases/knowledge_graph_embedding/kge_demo.ipynb).


## Comparing KGE with LLM performance on a large Knowledge Graph
Expand Down

0 comments on commit 3ffef1f

Please sign in to comment.