Commit b376787b authored by Armin Co

Update on environment wrapper

Updated the learning behaviour of the double Q agent and fixed the batch-size check for offline learning.
parent e09a4dd0
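The hunks below only show fragments of DQAgent.learn. As orientation, here is a hedged, numpy-only sketch of the double-Q update the method appears to perform: targets come from the second network (q2), the first network (q) is fitted on them, and the two networks are swapped at random each pass. TinyNet, double_q_learn, and the target-assignment line are illustrative stand-ins, not the project's actual Keras code.

# Hedged sketch of the double-Q update implied by the diff below.
# The real agent holds Keras models (predict_on_batch / fit); tiny linear
# "networks" stand in here so the example runs with numpy alone.
import numpy as np

class TinyNet:
    """Placeholder for the Keras model stored in self.q / self.q2 (assumption)."""
    def __init__(self, n_inputs, n_actions, rng):
        self.w = rng.normal(scale=0.1, size=(n_inputs, n_actions))

    def predict_on_batch(self, states):
        return states @ self.w

    def fit(self, states, targets, lr=1e-3):
        # One crude gradient step towards the targets (stand-in for model.fit).
        error = self.predict_on_batch(states) - targets
        self.w -= lr * states.T @ error / len(states)

def double_q_learn(q, q2, batch, gamma=0.99):
    """One update in the style of DQAgent.learn: targets use the other network."""
    states, actions, rewards, following_states, dones = batch
    q_max_hat = rewards + gamma * np.amax(q2.predict_on_batch(following_states), axis=1) * (1 - dones)
    y = q.predict_on_batch(states)
    y[np.arange(len(actions)), actions] = q_max_hat  # assumed target assignment, not shown in the hunk
    q.fit(states, y)

rng = np.random.default_rng(0)
q, q2 = TinyNet(8, 4, rng), TinyNet(8, 4, rng)
batch = (rng.normal(size=(32, 8)),          # states
         rng.integers(0, 4, size=32),       # actions
         rng.normal(size=32),               # rewards
         rng.normal(size=(32, 8)),          # following states
         rng.integers(0, 2, size=32))       # dones
for _ in range(3):                          # this commit bumps the loop from 2 to 3 passes
    if rng.random() < 0.5:                  # randomly swap which network gets updated
        q, q2 = q2, q
    double_q_learn(q, q2, batch)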
@@ -78,7 +78,7 @@ class DQAgent(QAgent):
         return np.argmax(action_values[0])

     def learn(self, offline=False):
-        for _ in range(2):
+        for _ in range(3):
             if np.random.rand() < 0.5:
                 temp = self.q
                 self.q = self.q2
@@ -87,8 +87,8 @@ class DQAgent(QAgent):
         epochs = 1
         if offline:
             batch_size = 4096
-        if len(self.memory.history) < self.online_batch_size:
-            return 0.0
+        if len(self.memory.history) < batch_size:
+            return
         states, actions, rewards, following_states, dones = self.memory.get_batch(batch_size)
         q_max_hat = rewards + self.gamma * (np.amax(self.q2.predict_on_batch(following_states), axis=1)) * (1-dones)
         y = self.q.predict_on_batch(states)
@@ -76,11 +76,11 @@ def run(environment, agent, episodes, render=True, learn=True):
     return score_history, avg_score_history


-def process_logs(avg_score_history, loss, title="Title"):
+def process_logs(avg_score_history, loss, title="Title", render=False):
     """ Plot the log history """
     plt.plot([i+1 for i in range(0, len(loss), 2)], loss[::2])
     plt.plot([i+1 for i in range(0, len(avg_score_history), 2)], avg_score_history[::2], '--')
     plt.title(title)
-    plt.show()
     plt.savefig(title + '.png', format="png")
+    if render:
+        plt.show()
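With this change, process_logs always writes <title>.png and only opens an interactive window when render=True. A hedged usage sketch, assuming environment_wrapper.py is importable from the working directory and using made-up score and loss values:

import matplotlib
matplotlib.use("Agg")           # headless backend is enough when render stays False
import environment_wrapper as ew

avg_score = [10.0, 12.5, 20.0, 35.0]   # illustrative values, not real training output
loss = [1.2, 0.9, 0.7, 0.5]

# Saves 'FromScratchDouble.png'; no window is opened unless render=True.
ew.process_logs(avg_score, loss, title="FromScratchDouble", render=False)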
@@ -6,7 +6,7 @@ import os
 import atexit
 import gym

-from agents import QAgent
+from agents import DQAgent as QAgent
 import environment_wrapper as ew

 # Allow GPU usage or force tensorflow to use the CPU.
@@ -21,12 +21,14 @@ if __name__ == '__main__':
     env = gym.make('LunarLander-v2')

     # 2. Create a learning agent
-    marvin = QAgent(env.action_space.n, env.observation_space.shape[0], 'from_scratch')
+    marvin = QAgent(env.action_space.n, env.observation_space.shape[0], 'FromScratchDouble')

     # (2.5) *optional* Load agent memory and/or net from disk.
-    LOAD_MEMORIES = False
-    LOAD_ANN = False
-    marvin.load('saved_agents/agent/agent', net=LOAD_ANN, memory=LOAD_MEMORIES)
+    agnt = 'agent'
+    LOAD_ANN = False
+    LOAD_MEMORIES = False
+    if LOAD_ANN or LOAD_MEMORIES:
+        marvin.load('saved_agents/' + agnt + '/' + agnt, net=LOAD_ANN, memory=LOAD_MEMORIES)

     # 3. Set your configurations for the run.
     RENDER = False
@@ -60,5 +62,4 @@ if __name__ == '__main__':
     marvin.save(SAVE_PATH)

     # Show the result of the runl.
-    if RENDER:
-        ew.process_logs(avg_score, loss, title=marvin.name)
+    ew.process_logs(avg_score, loss, title=marvin.name, render=RENDER)