Commit ab185eb3 authored by Armin Co

DoubleQAgent uses both networks to predict the action.

parent ad48f62a
@@ -33,22 +33,22 @@ class QAgent:
         if offline:
             batch_size = 4096
-        if len(self.memory.history) < self.online_batch_size:
+        if len(self.memory.history) < batch_size:
             return
         states, actions, rewards, following_states, dones = self.memory.get_batch(
             batch_size)
-        targets = rewards + self.gamma * \
+        qMax = rewards + self.gamma * \
             (np.amax(self.q.predict_on_batch(following_states), axis=1)) * (1-dones)
-        q_targets = self.q.predict_on_batch(states)
+        y = self.q.predict_on_batch(states)
         idx = np.array([i for i in range(batch_size)])
-        q_targets[[idx], [actions]] = targets
+        y[[idx], [actions]] = qMax
         if offline:
-            history = self.q.net.fit(states, q_targets, epochs=2, verbose=0)
+            history = self.q.net.fit(states, y, epochs=2, verbose=0)
             loss = history.history['loss'][-1]
         else:
-            loss = self.q.fit(states, q_targets, epochs)
+            loss = self.q.fit(states, y, epochs)
         if self.epsilon > self.epsilon_min:
             self.epsilon *= self.epsilon_decay
         return loss
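For context, the target construction in this hunk can be sketched in isolation: for each sampled transition, only the entry of the action that was actually taken is overwritten with the Bellman target r + gamma * max_a Q(s', a), while all other entries keep the network's own prediction, so they contribute no error. The sketch below is not repository code; batch values and array names are illustrative stand-ins for the predict_on_batch results.

import numpy as np

# Minimal sketch (assumed data, not from the repository) of the target update above.
batch_size, n_actions, gamma = 4, 2, 0.99
rewards = np.array([1.0, 0.0, 0.5, 1.0])
dones = np.array([0, 0, 1, 0])
actions = np.array([0, 1, 1, 0])
q_next = np.random.rand(batch_size, n_actions)  # stand-in for q.predict_on_batch(following_states)
y = np.random.rand(batch_size, n_actions)       # stand-in for q.predict_on_batch(states)

# Bellman target, zeroing the bootstrap term on terminal transitions.
q_max = rewards + gamma * np.amax(q_next, axis=1) * (1 - dones)
idx = np.arange(batch_size)
y[idx, actions] = q_max                         # same effect as y[[idx], [actions]] = qMax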
@@ -59,10 +59,11 @@ class QAgent:
         self.q.save(path+'/' + self.name + '.net')
         self.memory.save(path+'/' + self.name + '.mem')
 
-    def load(self, path, net=False):
+    def load(self, path, net=True, memory=True):
         print(path)
         if net:
             self.q.load(path+'.net')
-        self.memory.load(path+'.mem')
+        if memory:
+            self.memory.load(path+'.mem')
 
 class DQAgent(QAgent):
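With the new signature, network weights and replay memory can be restored independently. A hypothetical call pattern is shown below; the constructor arguments and checkpoint prefix are made up for illustration and assume QAgent is importable from the project's module.

# Hypothetical usage sketch of the extended load() flags (paths and arguments are illustrative).
agent = QAgent(action_space=4, state_space=8, name='lunar')
agent.load('saved_agents/lunar', net=True, memory=False)  # restore weights only, skip the replay memory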
@@ -70,38 +71,34 @@ class DQAgent(QAgent):
         super().__init__(action_space, state_space, name)
         self.q2 = QNet(action_space, state_space)
 
-    def learn(self, offline=False):
+    def get_action(self, state):
+        if np.random.rand() <= self.epsilon:
+            return random.randrange(self.action_space)
+        action_values = (self.q.predict(state) + self.q2.predict(state)) / 2
+        return np.argmax(action_values[0])
+
+    def learn(self, offline=False):
         for _ in range(2):
             if np.random.rand() < 0.5:
                 temp = self.q
                 self.q = self.q2
                 self.q2 = temp
             batch_size = self.online_batch_size
             epochs = 1
             if offline:
                 batch_size = 4096
             if len(self.memory.history) < self.online_batch_size:
-                return
+                return 0.0
             states, actions, rewards, following_states, dones = self.memory.get_batch(batch_size)
-            targets = rewards + self.gamma * (np.amax(self.q2.predict_on_batch(following_states), axis=1)) * (1-dones)
-            q_targets = self.q.predict_on_batch(states)
+            q_max_hat = rewards + self.gamma * (np.amax(self.q2.predict_on_batch(following_states), axis=1)) * (1-dones)
+            y = self.q.predict_on_batch(states)
             idx = np.array([i for i in range(batch_size)])
-            q_targets[[idx], [actions]] = targets
+            y[[idx], [actions]] = q_max_hat
             if offline:
-                history = self.q.net.fit(states, q_targets, epochs=2, verbose=0)
+                history = self.q.net.fit(states, y, epochs=2, verbose=0)
                 loss = history.history['loss'][-1]
             else:
-                loss = self.q.fit(states, q_targets, epochs)
+                loss = self.q.fit(states, y, epochs)
             if self.epsilon > self.epsilon_min:
                 self.epsilon *= self.epsilon_decay
             return loss
\ No newline at end of file
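The new get_action method is what the commit message refers to: both networks score the state and the greedy action is taken over their mean, which is the standard way a double estimator reduces the overestimation bias of a single Q-network. The snippet below is a standalone sketch of that selection rule with made-up Q-value arrays standing in for self.q.predict(state) and self.q2.predict(state); it is not the repository's class.

import random
import numpy as np

# Minimal sketch of the averaged, epsilon-greedy action selection added in this commit.
# q1_values / q2_values stand in for the (1, action_space) outputs of the two networks.
def get_action(q1_values, q2_values, epsilon, action_space):
    if np.random.rand() <= epsilon:
        return random.randrange(action_space)       # explore with probability epsilon
    action_values = (q1_values + q2_values) / 2     # average both estimators
    return np.argmax(action_values[0])              # greedy action over the mean estimate

print(get_action(np.array([[0.1, 0.9]]), np.array([[0.2, 0.4]]), epsilon=0.05, action_space=2))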