-
Notifications
You must be signed in to change notification settings - Fork 106
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7fd36fb
commit 3ffef1f
Showing
2 changed files
with
2 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"cells":[{"cell_type":"code","source":["# Install required packages.\n","import os\n","import torch\n","os.environ['TORCH'] = torch.__version__\n","print(torch.__version__)\n","\n","!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html\n","!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html\n","!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git"],"metadata":{"id":"p7nHdgFVTNFU"},"id":"p7nHdgFVTNFU","execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":3,"id":"4d7bac95-721a-49ef-b71b-973a0753764d","metadata":{"id":"4d7bac95-721a-49ef-b71b-973a0753764d","executionInfo":{"status":"ok","timestamp":1702389500296,"user_tz":-60,"elapsed":7,"user":{"displayName":"Richard Kiss","userId":"07987192335421876970"}}},"outputs":[],"source":["import torch\n","import json\n","import torch.optim as optim\n","from torch_geometric.nn import DistMult\n","from torch_geometric.datasets import FB15k_237"]},{"cell_type":"code","execution_count":4,"id":"9b40010b-c87d-4b5f-9e00-d2d3f43549e3","metadata":{"id":"9b40010b-c87d-4b5f-9e00-d2d3f43549e3","executionInfo":{"status":"ok","timestamp":1702389500297,"user_tz":-60,"elapsed":5,"user":{"displayName":"Richard Kiss","userId":"07987192335421876970"}}},"outputs":[],"source":["device = \"cuda\"\n","path = \"./data\""]},{"cell_type":"code","execution_count":null,"id":"2a8e1d9c-7808-49e4-a478-9f640d2318c8","metadata":{"id":"2a8e1d9c-7808-49e4-a478-9f640d2318c8"},"outputs":[],"source":["train_data = FB15k_237(path, split='train')[0].to(device)"]},{"cell_type":"code","execution_count":9,"id":"5246723c-2ecf-41a7-8710-02c7eee864c8","metadata":{"id":"5246723c-2ecf-41a7-8710-02c7eee864c8","executionInfo":{"status":"ok","timestamp":1702389555692,"user_tz":-60,"elapsed":272,"user":{"displayName":"Richard Kiss","userId":"07987192335421876970"}}},"outputs":[],"source":["# initialize model\n","model = DistMult(\n"," num_nodes=train_data.num_nodes,\n"," num_relations=train_data.num_edge_types,\n"," hidden_channels=64\n",").to(device)\n","\n","# initialize optimizer\n","opt = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-6)\n","\n","# create data loader on the training set\n","loader = model.loader(\n"," head_index=train_data.edge_index[0],\n"," rel_type=train_data.edge_type,\n"," tail_index=train_data.edge_index[1],\n"," batch_size=2000,\n"," shuffle=True,\n",")"]},{"cell_type":"code","execution_count":null,"id":"ad1a2c7a-e188-43f6-884d-f4ccee92f7d7","metadata":{"id":"ad1a2c7a-e188-43f6-884d-f4ccee92f7d7"},"outputs":[],"source":["EPOCHS = 100\n","model.train()\n","# usual torch training loop\n","for e in range(EPOCHS):\n"," l = []\n"," for batch in loader:\n"," opt.zero_grad()\n"," loss = model.loss(*batch)\n"," l.append(loss.item())\n"," loss.backward()\n"," opt.step()\n"," print(f\"Epoch {e} loss {sum(l) / len(l):.4f}\")"]},{"cell_type":"code","execution_count":null,"id":"8ff0c01b-b0ee-4c4c-a153-768704e22155","metadata":{"id":"8ff0c01b-b0ee-4c4c-a153-768704e22155"},"outputs":[],"source":["model.to(\"cpu\").eval()"]},{"cell_type":"code","execution_count":15,"id":"5d08643b-4664-459e-8c97-0e281117f80f","metadata":{"id":"5d08643b-4664-459e-8c97-0e281117f80f","executionInfo":{"status":"ok","timestamp":1702389623886,"user_tz":-60,"elapsed":4,"user":{"displayName":"Richard Kiss","userId":"07987192335421876970"}}},"outputs":[],"source":["france = 637 # entity France\n","rel = 15 # relation /location/location/contains\n","burgundy = 638 # entity Burgundy\n","riodj = 986 # entity Rio de Janeiro\n","bnc = 7485 # Bonnie and Clyde"]},{"cell_type":"code","execution_count":null,"id":"c875579b-270b-4f67-8e80-5d7473e94fb5","metadata":{"id":"c875579b-270b-4f67-8e80-5d7473e94fb5"},"outputs":[],"source":["# Define triples\n","head_entities = torch.tensor([france, france, france], dtype=torch.long)\n","relationships = torch.tensor([rel, rel, rel], dtype=torch.long)\n","tail_entities = torch.tensor([burgundy, riodj, bnc], dtype=torch.long)\n","\n","# Score triples using the model\n","scores = model(head_entities, relationships, tail_entities)\n","print(scores.tolist())"]},{"cell_type":"code","execution_count":17,"id":"f0928d3d-f0a8-4a3b-a9f2-62ad985b764d","metadata":{"id":"f0928d3d-f0a8-4a3b-a9f2-62ad985b764d","executionInfo":{"status":"ok","timestamp":1702389627339,"user_tz":-60,"elapsed":307,"user":{"displayName":"Richard Kiss","userId":"07987192335421876970"}}},"outputs":[],"source":["guy_ritchie = 5292 # entity Guy Ritchie\n","profession = 17 # relation /people/person/profession"]},{"cell_type":"code","execution_count":18,"id":"870fea6b-80ef-4bf3-bc75-6f26e1890b09","metadata":{"id":"870fea6b-80ef-4bf3-bc75-6f26e1890b09","executionInfo":{"status":"ok","timestamp":1702389627340,"user_tz":-60,"elapsed":4,"user":{"displayName":"Richard Kiss","userId":"07987192335421876970"}}},"outputs":[],"source":["# Accessing node and relation embeddings\n","node_embeddings = model.node_emb.weight\n","relation_embeddings = model.rel_emb.weight\n","\n","# Selecting specific entities and relations\n","guy_ritchie = node_embeddings[guy_ritchie]\n","profession = relation_embeddings[profession]"]},{"cell_type":"code","execution_count":19,"id":"b4debcc1-26bd-4e8f-809c-4a7f298bd1b7","metadata":{"id":"b4debcc1-26bd-4e8f-809c-4a7f298bd1b7","executionInfo":{"status":"ok","timestamp":1702389628103,"user_tz":-60,"elapsed":3,"user":{"displayName":"Richard Kiss","userId":"07987192335421876970"}}},"outputs":[],"source":["# Creating embedding for the query based on the chosen relation and entity\n","query = guy_ritchie * profession\n","\n","# Calculating scores using vector operations\n","scores = node_embeddings @ query\n","\n","# Find the index for the top 5 scores\n","sorted_indices = scores.argsort().tolist()[-5:][::-1]\n","# Get the score for the top 5 index\n","top_5_scores = scores[sorted_indices]"]},{"cell_type":"code","execution_count":null,"id":"655d7a74-6917-44f6-8313-d8066be8bc97","metadata":{"id":"655d7a74-6917-44f6-8313-d8066be8bc97"},"outputs":[],"source":["# List top 5 hits with scores\n","list(zip(sorted_indices, top_5_scores.tolist()))"]}],"metadata":{"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.12"},"colab":{"provenance":[],"gpuType":"T4"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":5} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters