""" Wrapper to abstract different learning environments for an agent. """
import os
import numpy as np
from tqdm import trange
import pandas as pd
import matplotlib.pyplot as plt


def step(environment, action):
    """ Perform one iteratino in the environment. """
    following_state, reward, done, _ = environment.step(action)
    following_state = np.reshape(following_state, (1, environment.observation_space.shape[0]))
    return following_state, reward, done, _

def reset(environment):
    """ Reset the environment, and return the new state. """
    state = environment.reset()
    state = np.reshape(state, (1, environment.observation_space.shape[0]))
    return state
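
# Example usage of the wrappers above, assuming the pre-0.26 gym API in which
# env.step returns a 4-tuple (the environment choice is illustrative):
#
#   import gym
#   env = gym.make('CartPole-v0')
#   state = reset(env)                               # shape (1, 4)
#   next_state, reward, done, info = step(env, env.action_space.sample())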


def one_episode(environment, agent, render, learn, conf=None, max_steps=1000):
    """ Perform one episode of the agent in the environment. """
    score = 0
    if conf is not None and conf.env_type == 'Carla':
        max_steps = 300
    state = reset(environment)
    for _ in range(max_steps):
        if render:
            environment.render()
        action = agent.get_action(state)
        following_state, reward, done, _ = step(environment, action)
        agent.remember(state, action, reward, following_state, done)
        score += reward
        state = following_state
        if learn:
            if conf is not None:
                agent.learn(epochs=conf.learn_epochs)
            else:
                agent.learn()
        if done:
            break
    return score
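
# A single evaluation episode using the helpers above (assumes `env` and
# `agent` exist as in the earlier sketch; learning disabled):
#
#   score = one_episode(env, agent, render=False, learn=False)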

# Average score at which the task counts as solved; 195 is the classic
# solving threshold for gym's CartPole-v0 (mean over recent episodes).
IS_SOLVED = 195
def learn_offline(agent, conf):
    """ Train the agent with its memories. """
    print('Learning with %d memories.' % len(agent.memory.history))

    score_history = []
    avg_score_history = []
    desc_train = ''
    pbar = trange(conf.offline_episodes, desc='Loss: x')
    for i in pbar:
        loss = agent.learn(offline=True, epochs=conf.learn_epochs)
        desc = ('Loss: %05.4f' %(loss)) + desc_train
        pbar.set_description(desc)
        pbar.refresh()
        if loss > 1000:
            raise RuntimeError('Loss exceeded 1000, aborting offline training.')
        if conf.offline_validate_every_x_iteration != -1 and i % conf.offline_validate_every_x_iteration == 1:
            agent.epsilon = agent.epsilon_min
            score = one_episode(conf.env, agent, conf.render, False, conf=conf)
            score_history.append(score)
            is_solved = np.mean(score_history[-25:])
            desc_train = (', Avg: %05.1f' %(is_solved))
            avg_score_history.append(is_solved)
            if is_solved > IS_SOLVED:
                break
            
    if conf.offline_validate_every_x_iteration != -1:
        process_logs(avg_score_history, score_history, conf)
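
# Sketch of an offline training call (the attribute values are placeholders;
# the names mirror those read above, and -1 disables validation):
#
#   conf.offline_episodes = 500
#   conf.learn_epochs = 4
#   conf.offline_validate_every_x_iteration = 25
#   learn_offline(conf.agent, conf)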



def run(conf):
    """ Run an agent in the configured environment. """
    conf.name = str(conf.name) + 'on'  # suffix the run name ('on' presumably marks an online run)
    learn = conf.learn and conf.learn_online
    if not learn:
        # Not learning: fix the exploration rate (epsilon-greedy) at its minimum.
        conf.agent.epsilon = conf.agent.epsilon_min

    score_history = []
    avg_score_history = []
    
    pbar = trange(conf.run_episodes, desc=conf.agent.name + ' [act, avg]: [0, 0]', unit="Episodes")
    for _ in pbar:
        score = one_episode(conf.env, conf.agent, conf.render, learn, conf=conf)
        score_history.append(score)

        is_solved = np.mean(score_history[-100:])
        avg_score_history.append(is_solved)

        desc = (conf.agent.name + " [act, avg]: [{0:.2f}, {1:.2f}]".format(score, is_solved))
        pbar.set_description(desc)
        pbar.refresh()
        if is_solved > IS_SOLVED and learn:
            break
    return score_history, avg_score_history
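
# A minimal sketch of driving run() end to end. The config object and its
# values are hypothetical; the attribute names mirror those read in this
# module:
#
#   from types import SimpleNamespace
#   conf = SimpleNamespace(name='DQN', env=env, env_type='CartPole',
#                          agent=agent, learn=True, learn_online=True,
#                          learn_epochs=1, render=False, run_episodes=500,
#                          save_to='logs/')
#   scores, avg_scores = run(conf)
#   process_logs(avg_scores, scores, conf)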


def process_logs(avg_score_history, score_history, conf):
    """ Save the score history to CSV and plot it next to its running average. """
    df = pd.DataFrame(list(zip(score_history, avg_score_history)), columns=['Score', 'Average'])
    out_dir = os.path.join(conf.save_to, conf.name)
    os.makedirs(out_dir, exist_ok=True)
    df.to_csv(os.path.join(out_dir, conf.name + '.csv'))

    plt.figure()
    plt.plot(df['Score'], label='Episode Score')
    plt.plot(df['Average'], '--', label='Average Score')
    plt.xlabel('Episode')
    plt.ylabel('Score')
    plt.legend()
    plt.title(conf.name)
    plt.savefig(os.path.join(out_dir, conf.name + '.png'), format="png")

def load_logs(file):
    """ Load a log CSV written by process_logs and return its columns. """
    df = pd.read_csv(file)
    return df["Score"], df["Average"]