""" Wrapper to abstract different learning environments for an agent. """
import numpy as np
from tqdm import trange
import matplotlib.pyplot as plt


def step(environment, action):
    """ Perform one iteration in the environment. """
    following_state, reward, done, info = environment.step(action)
    following_state = np.reshape(following_state, (1, environment.observation_space.shape[0]))
    return following_state, reward, done, info


def reset(environment):
    """ Reset the environment, and return the new state. """
    state = environment.reset()
    state = np.reshape(state, (1, environment.observation_space.shape[0]))
    return state


def one_episode(environment, agent, render, learn, max_steps=3000):
    """ Perform one episode of the agent in the environment. """
    score = 0
    state = reset(environment)
    for _ in range(max_steps):
        if render:
            environment.render()
        action = agent.get_action(state)
        following_state, reward, done, _ = step(environment, action)
        agent.remember(state, action, reward, following_state, done)
        score += reward
        state = following_state
        if learn:
            agent.learn()
        if done:
            if learn:
                agent.learn(offline=True)
            break
    return score


def learn_offline(agent, epochs=1):
    """ Train the agent with its memories. """
    print('Learning with', len(agent.memory.history), 'memories.')
    pbar = trange(epochs, desc='Loss: x')
    for _ in pbar:
        loss = agent.learn(offline=True)
        pbar.set_description('Loss: %05.4f' % loss)
        pbar.refresh()


def run(environment, agent, episodes, render=True, learn=True):
    """ Run the agent in the environment for a number of episodes. """
    # When only evaluating, fix the exploration rate (epsilon-greedy) at its minimum.
    if not learn:
        agent.epsilon = agent.epsilon_min
    score_history = []
    avg_score_history = []
    pbar = trange(episodes, desc='Score [actual, average]: [0, 0]', unit="Episodes")
    for _ in pbar:
        score = one_episode(environment, agent, render, learn)
        score_history.append(score)
        average = np.mean(score_history[-50:])
        avg_score_history.append(average)
        # Stop training once the average score over the last 50 episodes exceeds 200.
        if average > 200 and learn:
            break
        pbar.set_description("Score [actual, average]: [{0:.2f}, {1:.2f}]".format(score, average))
        pbar.refresh()
    return score_history, avg_score_history


def process_logs(avg_score_history, loss, title="Title"):
    """ Plot the loss and average-score history. """
    plt.plot([i + 1 for i in range(0, len(loss), 2)], loss[::2])
    plt.plot([i + 1 for i in range(0, len(avg_score_history), 2)], avg_score_history[::2], '--')
    plt.title(title)
    # Save before showing: plt.show() can leave an empty figure for savefig().
    plt.savefig(title + '.png', format="png")
    plt.show()
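

# Minimal usage sketch, assuming a Gym-style environment (old 4-tuple step API)
# and an agent object exposing get_action(), remember(), learn(), epsilon and
# epsilon_min as used above. 'DQAgent' and its module are illustrative
# placeholders, not part of this file.
if __name__ == "__main__":
    import gym
    from agent import DQAgent  # hypothetical agent implementation

    env = gym.make('LunarLander-v2')
    agent = DQAgent(env.observation_space.shape[0], env.action_space.n)
    scores, avg_scores = run(env, agent, episodes=500, render=False, learn=True)
    # process_logs plots its second argument (named 'loss') as the solid line;
    # here the raw per-episode scores are passed in its place.
    process_logs(avg_scores, scores, title='LunarLander')
    env.close()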