Skip to content

Commit

Permalink
Now it uses the GPU only
Browse files Browse the repository at this point in the history
I have transferred all variables to the GPU, which makes them faster for parallel computation
  • Loading branch information
SwayamThapliyal committed Nov 3, 2023
1 parent 557dca3 commit d0f3efc
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Game.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Direction(Enum):
BLACK = (0,0,0)

BLOCK_SIZE = 20
SPEED = 40
SPEED = 80

class SnakeGameAI:

Expand Down
Binary file modified __pycache__/Game.cpython-311.pyc
Binary file not shown.
Binary file modified __pycache__/model.cpython-311.pyc
Binary file not shown.
10 changes: 6 additions & 4 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from model import Linear_QNet, QTrainer
from helper import plot

MAX_MEMORY = 100_000
BATCH_SIZE = 1000
MAX_MEMORY = 200_000
BATCH_SIZE = 2000
LR = 0.001

class Agent:
Expand All @@ -18,6 +18,7 @@ def __init__(self):
self.gamma = 0.9 # discount rate
self.memory = deque(maxlen=MAX_MEMORY) # popleft()
self.model = Linear_QNet(11, 256, 3)
self.model.to('cuda')
self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)


Expand Down Expand Up @@ -65,7 +66,7 @@ def get_state(self, game):
game.food.y > game.head.y # food down
]

return np.array(state, dtype=int)
return np.array(state, dtype = int)

def remember(self, state, action, reward, next_state, done):
    """Store one transition (state, action, reward, next_state, done) in replay memory.

    The memory is a bounded deque, so once MAX_MEMORY entries are held the
    oldest transition is discarded automatically on append.
    """
    experience = (state, action, reward, next_state, done)
    self.memory.append(experience)
Expand All @@ -92,7 +93,8 @@ def get_action(self, state):
move = random.randint(0, 2)
final_move[move] = 1
else:
state0 = torch.tensor(state, dtype=torch.float)
#print(state)
state0 = torch.tensor(state, dtype=torch.float,device='cuda')
prediction = self.model(state0)
move = torch.argmax(prediction).item()
final_move[move] = 1
Expand Down
14 changes: 10 additions & 4 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@ class Linear_QNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
    """Build a fully-connected Q-network with three hidden layers.

    Args:
        input_size: number of state features fed to the network.
        hidden_size: width of every hidden layer.
        output_size: number of actions, i.e. Q-values emitted.
    """
    super().__init__()
    self.linear1 = nn.Linear(input_size, hidden_size)
    self.linear2 = nn.Linear(hidden_size, hidden_size)
    self.linear3 = nn.Linear(hidden_size, hidden_size)
    # NOTE: the output layer keeps the historical name `linear5` (a fourth
    # hidden layer was removed) so checkpoints saved under these attribute
    # names still load; renaming it would break torch.load of model.pth.
    self.linear5 = nn.Linear(hidden_size, output_size)

def forward(self, x):
    """Forward pass: three ReLU-activated hidden layers, then a linear head.

    Args:
        x: state tensor — a single feature vector or a batch of them.

    Returns:
        Raw (unactivated) Q-value estimates, one per action; left unbounded
        because Q-values may be negative.
    """
    x = F.relu(self.linear1(x))
    x = F.relu(self.linear2(x))
    x = F.relu(self.linear3(x))
    return self.linear5(x)

def save(self, file_name='model.pth'):
Expand All @@ -33,8 +39,8 @@ def __init__(self, model, lr, gamma):
self.criterion = nn.MSELoss()

def train_step(self, state, action, reward, next_state, done):
state = torch.tensor(state, dtype=torch.float)
next_state = torch.tensor(next_state, dtype=torch.float)
state = torch.tensor(state, dtype=torch.float,device ='cuda')
next_state = torch.tensor(next_state, dtype=torch.float,device= 'cuda')
action = torch.tensor(action, dtype=torch.long)
reward = torch.tensor(reward, dtype=torch.float)
# (n, x)
Expand Down
Binary file modified model/model.pth
Binary file not shown.

0 comments on commit d0f3efc

Please sign in to comment.