environment_wrapper.py
    """ Wrapper to abstract different learning environments for an agent. """
    import numpy as np
    
    from tqdm import trange
    import matplotlib.pyplot as plt
    
    
    def step(environment, action):
        """ Perform one iteratino in the environment. """
        following_state, reward, done, _ = environment.step(action)
        following_state = np.reshape(following_state, (1, environment.observation_space.shape[0]))
        return following_state, reward, done, _
    
    def reset(environment):
        """ Reset the environment, and return the new state. """
        state = environment.reset()
        state = np.reshape(state, (1, environment.observation_space.shape[0]))
        return state
    
    
    def one_episode(environment, agent, render, learn, max_steps=3000):
        """ Perform one episode of the agent in the environment. """
        score = 0
        state = reset(environment)
        for _ in range(max_steps):
            if render:
                environment.render()
            action = agent.get_action(state)
            following_state, reward, done, _ = step(environment, action)
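            # Store the transition so the agent can learn from it later
            # (e.g. via an experience-replay memory).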
            agent.remember(state, action, reward, following_state, done)
            score += reward
            state = following_state
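            # Online update after every step; a final offline (presumably
            # batch/replay) update once the episode ends.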
            if learn:
                agent.learn()
            if done:
                if learn:
                    agent.learn(offline=True)
                break
        return score
    
    def learn_offline(agent, epochs=1):
        """ Train the agent with its memories. """
        print('Learning with', len(agent.memory.history), 'memories.')
        pbar = trange(epochs, desc='Loss: x')
        for _ in pbar:
            loss = agent.learn(offline=True)
            desc = ('Loss: %05.4f' %(loss))
            pbar.set_description(desc)
            pbar.refresh()
    
    def run(environment, agent, episodes, render=True, learn=True):
        """ Run an agent """
    
        # When only evaluating (not learning), keep exploration fixed at its
        # minimum epsilon (epsilon-greedy).
        if not learn:
            agent.epsilon = agent.epsilon_min
    
        score_history = []
        avg_score_history = []
        
        pbar = trange(episodes, desc='Score [actual, average]: [0, 0]', unit="Episodes")
        for _ in pbar:
            score = one_episode(environment, agent, render, learn)
            score_history.append(score)
    
            average = np.mean(score_history[-50:])
            avg_score_history.append(average)
    
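            # Treat a 50-episode rolling average above 200 as solved
            # (e.g. the usual threshold for LunarLander-v2) and stop training.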
            if average > 200 and learn:
                break
            desc = "Score [actual, average]: [{0:.2f}, {1:.2f}]".format(score, average)
            pbar.set_description(desc)
            pbar.refresh()
        return score_history, avg_score_history
    
    
    def process_logs(avg_score_history, loss, title="Title"):
        """ Plot the log history """
        plt.plot([i+1 for i in range(0, len(loss), 2)], loss[::2])
        plt.plot([i+1 for i in range(0, len(avg_score_history), 2)], avg_score_history[::2], '--')
        plt.title(title)
        plt.legend(['loss', 'average score'])
        # Save before show(): show() may close the figure, leaving savefig() with a blank canvas.
        plt.savefig(title + '.png', format="png")
        plt.show()
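    
    
    if __name__ == '__main__':
        # Minimal usage sketch, not part of the original module: it exercises the
        # wrapper with a throw-away random agent on CartPole-v1 so the agent
        # interface the functions above rely on (get_action, remember, learn,
        # epsilon, epsilon_min, memory.history) is visible in one place. The
        # project's real agent class is assumed to live elsewhere in the
        # repository, and the classic `gym` (< 0.26) step/reset API is assumed.
        import gym
    
        class _RandomAgent:
            """ Dummy agent that only satisfies the interface used above. """
    
            class _Memory:
                def __init__(self):
                    self.history = []
    
            def __init__(self, action_space):
                self.action_space = action_space
                self.epsilon = 1.0
                self.epsilon_min = 0.01
                self.memory = self._Memory()
    
            def get_action(self, state):
                return self.action_space.sample()
    
            def remember(self, state, action, reward, following_state, done):
                self.memory.history.append((state, action, reward, following_state, done))
    
            def learn(self, offline=False):
                return 0.0  # a real agent would return its training loss here
    
        env = gym.make('CartPole-v1')
        dummy = _RandomAgent(env.action_space)
        scores, avg_scores = run(env, dummy, episodes=10, render=False, learn=False)
        print('Mean score over 10 random episodes:', np.mean(scores))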