import random
import numpy as np
from memory import Memory
from networks import QNet
from steering_wheel import Controller
from keras.callbacks import EarlyStopping

class QAgent:
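    """Deep Q-learning agent: epsilon-greedy policy over a QNet, trained from replay memory."""
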
    # Default hyperparameters; several are overridden from the config in __init__.
    gamma = 0.99                # discount factor for future rewards
    epsilon = 1.0               # initial exploration rate
    epsilon_min = 0.01          # lower bound for epsilon
    epsilon_decay = 0.9999      # multiplicative epsilon decay per learning step
    online_batch_size = 64      # batch size for online (per-step) training
    action_space = 1            # number of discrete actions, set from the environment
    name = "Q"
    OFFLINE_BATCHSIZE = 2048    # batch size for offline (replay) training

    def __init__(self, conf):
        self.q = QNet(conf)
        self.memory = Memory()
        self.action_space = conf.env.action_space.n
        self.name = conf.name
        self.epsilon_decay = conf.eps_decay
        self.OFFLINE_BATCHSIZE = conf.offline_batchsize

    def get_action(self, state):
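        """Epsilon-greedy action selection: explore with probability epsilon,
        otherwise pick the action with the highest predicted Q-value."""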
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_space)
        action_values = self.q.predict(state)
        return np.argmax(action_values[0])

    def remember(self, state, action, reward, following_state, done):
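        """Store a transition in the replay memory."""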
        self.memory.add(state, action, reward, following_state, done)

    def learn(self, offline=False, epochs=1):
        """Fit the Q-network on a batch sampled from replay memory and decay epsilon."""
        batch_size = self.online_batch_size
        if offline:
            batch_size = self.OFFLINE_BATCHSIZE

        # Wait until the replay memory holds enough transitions before training.
        if len(self.memory.history) < batch_size * 35:
            return

        states, actions, rewards, following_states, dones = self.memory.get_batch(
            batch_size)
        # Bellman target: r + gamma * max_a' Q(s', a'), zeroed for terminal transitions.
        q_target = rewards + self.gamma * \
            np.amax(self.q.predict_on_batch(following_states), axis=1) * (1 - dones)
        y = self.q.predict_on_batch(states)
        idx = np.arange(batch_size)
        # Only the Q-value of the action actually taken is moved towards the target.
        y[idx, actions] = q_target

        if offline:
            # Offline: train for several epochs directly on the underlying Keras model.
            history = self.q.net.fit(states, y, epochs=epochs, verbose=0)
            loss = history.history['loss'][-1]
        else:
            loss = self.q.fit(states, y, epochs)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
        return loss

    def save(self, path):
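        """Persist network weights and replay memory under path/<name>/."""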
        path += "/"+ self.name
        print('Save:  ' + path)
        self.q.save(path+'/' + self.name + '.net')
        self.memory.save(path+'/' + self.name + '.mem')

    def load(self, path, net=True, memory=True):
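        """Restore network weights and/or replay memory from the given path prefix."""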
        print('Load:  ' + path)
        if net:
            print('Network')
            self.q.load(path+'.net')
        if memory:
            self.memory.load(path+'.mem')

class DQAgent(QAgent):
    def __init__(self, conf):
        super().__init__(conf)
        self.q2 = QNet(conf)
        self.name = str(self.name) + 'DBL'

    def learn(self, offline=False, epochs=1):
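        """Double-Q variant of learn(): performs two updates, using one network to
        evaluate the next-state values for the other's targets."""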
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
        for _ in range(2):
            # Randomly swap the two networks so each is updated with targets
            # evaluated by the other (double Q-learning style update).
            if np.random.rand() < 0.5:
                self.q, self.q2 = self.q2, self.q
            batch_size = self.online_batch_size
            if offline:
                batch_size = self.OFFLINE_BATCHSIZE
            if len(self.memory.history) < batch_size * 35:
                return
            states, actions, rewards, following_states, dones = self.memory.get_batch(
                batch_size)
            # Target uses the *other* network's value estimate for the next state.
            q_target = rewards + self.gamma * \
                np.amax(self.q2.predict_on_batch(following_states), axis=1) * (1 - dones)
            y = self.q.predict_on_batch(states)
            idx = np.arange(batch_size)
            y[idx, actions] = q_target
            if offline:
                # Offline training stops early once the loss plateaus.
                callback = EarlyStopping(monitor='loss', patience=2, min_delta=0.1, restore_best_weights=True)
                history = self.q.net.fit(states, y, epochs=epochs, verbose=0, callbacks=[callback])
                loss = history.history['loss'][-1]
            else:
                loss = self.q.fit(states, y, epochs)
        return loss

    def load(self, path, net=True, memory=True):
        super().load(path, net=net, memory=memory)
        if net:
            # Both networks start from the same saved weights.
            self.q2.load(path+'.net')

class CarlaManual(QAgent):
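    """Agent variant that ignores the Q-network for action selection and instead
    reads actions from a manual controller (steering wheel input)."""
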
    control = None

    def __init__(self, conf):
        super().__init__(conf)
        self.control = Controller()

    def get_action(self, state):
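        """Poll the controller and return the manually chosen action instead of a Q-greedy one."""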
        self.control.on_update()
        return self.control.get_action()
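
# Minimal usage sketch (an assumption, not part of this module): the config is expected
# to expose at least the fields read above (env, name, eps_decay, offline_batchsize),
# where env is a gym-style environment; QNet may require further fields not shown here.
# The (1, -1) reshape assumes the network expects a leading batch dimension.
#
#   agent = QAgent(conf)
#   state = conf.env.reset()
#   action = agent.get_action(np.reshape(state, (1, -1)))
#   next_state, reward, done, _ = conf.env.step(action)
#   agent.remember(np.reshape(state, (1, -1)), action, reward,
#                  np.reshape(next_state, (1, -1)), done)
#   loss = agent.learn()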