Skip to content
This repository has been archived by the owner on Jul 23, 2024. It is now read-only.

Commit

Permalink
new test file
Browse files Browse the repository at this point in the history
  • Loading branch information
rageSpin committed Feb 9, 2023
1 parent ca37699 commit fe775aa
Showing 1 changed file with 109 additions and 0 deletions.
109 changes: 109 additions & 0 deletions Reinforcement Learning/LunarLander/_.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import torch\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define Policy Network"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class PolicyNetwork(torch.nn.Module):\n",
"    \"\"\"Feed-forward policy that outputs a categorical distribution over actions.\n",
"\n",
"    Three hidden layers of 128 units with LeakyReLU(0.1) activations are\n",
"    followed by a linear head and a softmax over the last dimension.\n",
"\n",
"    Args:\n",
"        n: Number of discrete actions (size of the output layer).\n",
"        in_dim: Size of the input (state) vector.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, n=4, in_dim=128):\n",
"        super().__init__()\n",
"        self.fc1 = torch.nn.Linear(in_dim, 128)\n",
"        self.fc2 = torch.nn.Linear(128, 128)\n",
"        self.fc3 = torch.nn.Linear(128, 128)\n",
"        self.fc4 = torch.nn.Linear(128, n)\n",
"        self.l_relu = torch.nn.LeakyReLU(0.1)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return action probabilities for ``x`` (softmax over the last dim).\"\"\"\n",
"        x = self.l_relu(self.fc1(x))\n",
"        x = self.l_relu(self.fc2(x))\n",
"        x = self.l_relu(self.fc3(x))\n",
"        # torch.softmax is used because the notebook imports only ``torch``,\n",
"        # so an ``F`` alias for torch.nn.functional is not available.\n",
"        return torch.softmax(self.fc4(x), dim=-1)\n",
"\n",
"    def _prepare_state(self, state):\n",
"        \"\"\"Convert ``state`` into a batched tensor the network can consume.\"\"\"\n",
"        # isinstance, not ``is``: an identity test against the class is never\n",
"        # true for an instance, so tensors used to be re-converted and crash.\n",
"        if not isinstance(state, torch.Tensor):\n",
"            # Use the device of the weights instead of an undefined global.\n",
"            device = next(self.parameters()).device\n",
"            state = torch.as_tensor(state, dtype=torch.float32, device=device)\n",
"        if state.dim() == 1:\n",
"            # A single observation becomes a batch of one.\n",
"            state = state.unsqueeze(0)\n",
"        return state\n",
"\n",
"    def sample_action(self, state):\n",
"        \"\"\"Sample an action from the policy for a single ``state``.\n",
"\n",
"        Returns:\n",
"            Tuple ``(action, log_probability)`` as a Python int and float.\n",
"        \"\"\"\n",
"        probabilities = self(self._prepare_state(state))\n",
"        dist = torch.distributions.Categorical(probabilities)\n",
"        action = dist.sample()\n",
"        return action.item(), dist.log_prob(action).item()\n",
"\n",
"    def best_action(self, state):\n",
"        \"\"\"Return the index of the most probable action as a Python int.\"\"\"\n",
"        probabilities = self(self._prepare_state(state)).squeeze()\n",
"        return torch.argmax(probabilities).item()\n",
"\n",
"    def evaluate_actions(self, states, actions):\n",
"        \"\"\"Return ``(log_probabilities, entropy)`` of ``actions`` given ``states``.\"\"\"\n",
"        dist = torch.distributions.Categorical(self(states))\n",
"        return dist.log_prob(actions), dist.entropy()\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "4f523f7c76dd18e7ed336217f32f6f704c23c323644912475b9d3570cf04b060"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit fe775aa

Please sign in to comment.