import random

import numpy as np
from keras.callbacks import EarlyStopping

from memory import Memory
from networks import QNet
from steering_wheel import Controller


class QAgent:
    """DQN agent: epsilon-greedy policy over a single Q-network with experience replay."""

    gamma = 0.99
    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.9999
    online_batch_size = 64
    action_space = 1
    name = "Q"
    OFFLINE_BATCHSIZE = 2048

    def __init__(self, conf):
        self.q = QNet(conf)
        self.memory = Memory()
        self.action_space = conf.env.action_space.n
        self.name = conf.name
        self.epsilon_decay = conf.eps_decay
        self.OFFLINE_BATCHSIZE = conf.offline_batchsize

    def get_action(self, state):
        """Return a random action with probability epsilon, otherwise the greedy action."""
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_space)
        action_values = self.q.predict(state)
        return np.argmax(action_values[0])

    def remember(self, state, action, reward, following_state, done):
        """Store a transition in the replay memory."""
        self.memory.add(state, action, reward, following_state, done)

    def learn(self, offline=False, epochs=1):
        """Fit the Q-network on a replay batch and return the training loss."""
        batch_size = self.OFFLINE_BATCHSIZE if offline else self.online_batch_size
        # Wait until the replay memory holds enough transitions.
        if len(self.memory.history) < batch_size * 35:
            return
        states, actions, rewards, following_states, dones = self.memory.get_batch(batch_size)
        # Bellman target: r + gamma * max_a' Q(s', a'), zeroed for terminal states.
        q_target = rewards + self.gamma * np.amax(
            self.q.predict_on_batch(following_states), axis=1) * (1 - dones)
        y = self.q.predict_on_batch(states)
        idx = np.arange(batch_size)
        y[idx, actions] = q_target  # update only the Q-values of the taken actions
        if offline:
            history = self.q.net.fit(states, y, epochs=epochs, verbose=0)
            loss = history.history['loss'][-1]
        else:
            loss = self.q.fit(states, y, epochs)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
        return loss

    def save(self, path):
        """Save network weights and replay memory under <path>/<name>/."""
        path += "/" + self.name
        print(path)
        self.q.save(path + '/' + self.name + '.net')
        self.memory.save(path + '/' + self.name + '.mem')

    def load(self, path, net=True, memory=True):
        print('Load: ' + path)
        if net:
            print('Network')
            self.q.load(path + '.net')
        if memory:
            self.memory.load(path + '.mem')


class DQAgent(QAgent):
    """Double-DQN agent: two Q-networks, swapped at random, one providing the targets for the other."""

    def __init__(self, conf):
        super().__init__(conf)
        self.q2 = QNet(conf)
        self.name = str(self.name) + 'DBL'

    def learn(self, offline=False, epochs=1):
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
        for _ in range(2):
            # Swap the roles of the two networks with probability 0.5.
            if np.random.rand() < 0.5:
                self.q, self.q2 = self.q2, self.q
            batch_size = self.OFFLINE_BATCHSIZE if offline else self.online_batch_size
            if len(self.memory.history) < batch_size * 35:
                return
            states, actions, rewards, following_states, dones = self.memory.get_batch(batch_size)
            # Targets come from the second network; updates are applied to the first.
            q_target = rewards + self.gamma * np.amax(
                self.q2.predict_on_batch(following_states), axis=1) * (1 - dones)
            y = self.q.predict_on_batch(states)
            idx = np.arange(batch_size)
            y[idx, actions] = q_target
            if offline:
                callback = EarlyStopping(monitor='loss', patience=2,
                                         min_delta=0.1, restore_best_weights=True)
                history = self.q.net.fit(states, y, epochs=epochs,
                                         verbose=0, callbacks=[callback])
                loss = history.history['loss'][-1]
            else:
                loss = self.q.fit(states, y, epochs)
        return loss

    def load(self, path, net=True, memory=True):
        super().load(path, net=net, memory=memory)
        if net:
            self.q2.load(path + '.net')


class CarlaManual(QAgent):
    """Agent that ignores the Q-network and reads actions from a manual (steering-wheel) controller."""

    control = None

    def __init__(self, conf):
        super().__init__(conf)
        self.control = Controller()

    def get_action(self, state):
        self.control.on_update()
        return self.control.get_action()
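

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's public API).
# Shows how an agent might be wired into a training loop, assuming a classic
# Gym-style environment (reset() -> state, step() -> (state, reward, done,
# info)) and a config object exposing the attributes read in
# QAgent.__init__: env, name, eps_decay and offline_batchsize. The
# SimpleNamespace config, the "CartPole-v1" choice and the state reshaping
# are assumptions for this sketch; QNet may read further fields from conf.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace
    import gym  # assumed dependency for the sketch

    env = gym.make("CartPole-v1")
    conf = SimpleNamespace(env=env, name="cartpole",
                           eps_decay=0.9999, offline_batchsize=2048)
    agent = QAgent(conf)

    state = env.reset()
    for _ in range(10_000):
        # QNet.predict is assumed to expect a batch dimension.
        action = agent.get_action(np.reshape(state, (1, -1)))
        following_state, reward, done, _ = env.step(action)
        agent.remember(state, action, reward, following_state, done)
        agent.learn()                      # online update on a small batch
        state = env.reset() if done else following_state

    agent.learn(offline=True, epochs=10)   # optional offline replay pass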