-
Notifications
You must be signed in to change notification settings - Fork 165
/
Copy pathtrain_agent_kerasrl.py
90 lines (68 loc) · 3.52 KB
/
train_agent_kerasrl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import pickle
import numpy as np
import gym
np.random.seed(123) # set a random seed when setting up the gym environment (train_test_split)
import gym_malware
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, ELU, Dropout, BatchNormalization
from keras.optimizers import Adam, SGD, RMSprop
# pip install keras-rl
from rl.agents.dqn import DQNAgent
from rl.agents.sarsa import SarsaAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory
def generate_dense_model(input_shape, layers, nb_actions):
    """Build the fully connected Q-network used by the agent.

    The network is: Flatten -> Dropout(0.1) -> one
    (Dense -> BatchNormalization -> ELU) group per entry in ``layers`` ->
    Dense(nb_actions) -> linear activation.

    Args:
        input_shape: Shape handed to the initial ``Flatten`` layer.
        layers: Iterable of hidden-layer widths, one ``Dense`` layer each.
        nb_actions: Number of output units (one per action).

    Returns:
        The assembled, uncompiled ``Sequential`` model.
    """
    model = Sequential()
    model.add(Flatten(input_shape=input_shape))
    # Drop out the input to make the model less sensitive to any one feature.
    model.add(Dropout(0.1))

    for units in layers:
        model.add(Dense(units))
        model.add(BatchNormalization())
        model.add(ELU(alpha=1.0))

    model.add(Dense(nb_actions))
    model.add(Activation('linear'))

    # summary() prints the table itself and returns None; wrapping it in
    # print() would only add a stray "None" line to the output.
    model.summary()

    return model
def train_dqn_model(layers, rounds=10000, run_test=False, use_score=False):
    """Train a DQN agent on the malware gym environment.

    Args:
        layers: Hidden-layer widths forwarded to ``generate_dense_model``.
        rounds: Number of environment steps passed to ``agent.fit``.
        run_test: When true, also run the agent for 100 episodes on the
            matching test environment.
        use_score: Selects the ``malware-score-*`` environments instead of
            the ``malware-*`` ones.

    Returns:
        Tuple ``(agent, model, history_train, history_test)``. The history
        values are the ``history`` attribute of the training and test
        environments; ``history_test`` is ``None`` when ``run_test`` is false.
    """
    if use_score:
        train_env_id = 'malware-score-v0'
    else:
        train_env_id = 'malware-v0'

    env = gym.make(train_env_id)
    env.seed(123)
    nb_actions = env.action_space.n

    # Each "experience" holds a single observation (window of length 1).
    window_length = 1

    # Q-network: the input is the observation shape prefixed by the window length.
    net_input_shape = (window_length,) + env.observation_space.shape
    model = generate_dense_model(net_input_shape, layers, nb_actions)

    # BoltzmannQPolicy picks actions stochastically, with probabilities
    # obtained by soft-maxing the Q values.
    policy = BoltzmannQPolicy()

    # Replay memory; only one sample (window_length=1) per "experience".
    memory = SequentialMemory(
        limit=32,
        ignore_episode_boundaries=False,
        window_length=window_length,
    )

    # DQN agent; see
    #   http://arxiv.org/pdf/1312.5602.pdf
    #   http://arxiv.org/abs/1509.06461
    agent = DQNAgent(
        model=model,
        nb_actions=nb_actions,
        memory=memory,
        nb_steps_warmup=16,
        enable_double_dqn=True,
        enable_dueling_network=True,
        dueling_type='avg',
        target_model_update=1e-2,
        policy=policy,
        batch_size=16,
    )

    # keras-rl accepts any built-in Keras optimizer.
    agent.compile(RMSprop(lr=1e-3), metrics=['mae'])

    # Play the game and learn.
    agent.fit(env, nb_steps=rounds, visualize=False, verbose=2)
    history_train = env.history

    if not run_test:
        return agent, model, history_train, None

    # Evaluation runs on the separate test environment.
    if use_score:
        test_env_id = 'malware-score-test-v0'
    else:
        test_env_id = 'malware-test-v0'
    test_env = gym.make(test_env_id)

    # Evaluate the agent on a few episodes drawn from the test samples.
    agent.test(test_env, nb_episodes=100, visualize=False)
    history_test = test_env.history

    return agent, model, history_train, history_test
if __name__ == '__main__':
    import os

    # Create the output directory before training starts. If 'models/' is
    # missing, the save calls below can fail, and that failure would only
    # surface after the full training run, discarding its result.
    os.makedirs('models', exist_ok=True)

    # Black-box run: the agent is trained without access to scores.
    agent1, model1, history_train1, history_test1 = train_dqn_model(
        [1024, 256], rounds=50000, run_test=True, use_score=False)
    model1.save('models/dqn.h5', overwrite=True)
    with open('history_blackbox.pickle', 'wb') as f:
        pickle.dump(history_test1, f, pickle.HIGHEST_PROTOCOL)

    # Score run: allow the agent to see scores.
    agent2, model2, history_train2, history_test2 = train_dqn_model(
        [1024, 256], rounds=50000, run_test=True, use_score=True)
    model2.save('models/dqn_score.h5', overwrite=True)
    with open('history_score.pickle', 'wb') as f:
        pickle.dump(history_test2, f, pickle.HIGHEST_PROTOCOL)