agents.py

    import random
    import numpy as np
    from memory import Memory
    from networks import QNet
    from steering_wheel import Controller
    
    class QAgent:
        """Deep Q-learning agent with epsilon-greedy exploration and replay memory."""

        gamma = 0.99              # discount factor
        epsilon = 1.0             # initial exploration rate
        epsilon_min = 0.005       # exploration floor
        epsilon_decay = 0.9999    # multiplicative decay applied after each update
        online_batch_size = 64
        action_space = 1
        name = "Q"
        OFFLINE_BATCHSIZE = 2048
    
        def __init__(self, conf):
            self.q = QNet(conf)
            self.memory = Memory()
            self.action_space = conf.env.action_space.n
            self.name = conf.name
            self.epsilon_decay = conf.eps_decay
    
        def get_action(self, state):
            """Epsilon-greedy action selection over the Q-network's estimates."""
            if np.random.rand() <= self.epsilon:
                return random.randrange(self.action_space)
            action_values = self.q.predict(state)
            return np.argmax(action_values[0])
    
        def remember(self, state, action, reward, following_state, done):
            self.memory.add(state, action, reward, following_state, done)
    
        def learn(self, offline=False, epochs=1):
            """Fit the Q-function on a batch of transitions from the replay memory."""
            batch_size = self.OFFLINE_BATCHSIZE if offline else self.online_batch_size

            if len(self.memory.history) < batch_size:
                # Not enough transitions collected yet.
                return

            states, actions, rewards, following_states, dones = self.memory.get_batch(
                batch_size)
            # Bellman target: r + gamma * max_a' Q(s', a'), zeroed for terminal states.
            q_max = rewards + self.gamma * np.amax(
                self.q.predict_on_batch(following_states), axis=1) * (1 - dones)
            y = self.q.predict_on_batch(states)
            idx = np.arange(batch_size)
            y[idx, actions] = q_max

            if offline:
                history = self.q.net.fit(states, y, epochs=epochs, verbose=0)
                loss = history.history['loss'][-1]
            else:
                loss = self.q.fit(states, y, epochs)
            if self.epsilon > self.epsilon_min:
                self.epsilon *= self.epsilon_decay
            return loss
    
        def save(self, path):
            """Persist the network weights and the replay memory under path/<name>/."""
            path += '/' + self.name
            print(path)
            self.q.save(path + '/' + self.name + '.net')
            self.memory.save(path + '/' + self.name + '.mem')
    
        def load(self, path, net=True, memory=True):
            """Load network weights and/or replay memory from the given path prefix."""
            print('Load: ' + path)
            if net:
                print('Network')
                self.q.load(path + '.net')
            if memory:
                self.memory.load(path + '.mem')
    
    class DQAgent(QAgent):
        """Double Q-learning agent that keeps two networks and averages them for acting."""

        def __init__(self, conf):
            super().__init__(conf)
            self.q2 = QNet(conf)
            self.name = 'D_' + str(self.name)

        def get_action(self, state):
            """Epsilon-greedy over the mean of both networks' Q-value estimates."""
            if np.random.rand() <= self.epsilon:
                return random.randrange(self.action_space)
            action_values = (self.q.predict(state) + self.q2.predict(state)) / 2
            return np.argmax(action_values[0])
    
        def learn(self, offline=False, epochs=1):
            """Run two updates, each time training one network against the other's targets."""
            for _ in range(2):
                # Randomly swap the roles of the two networks so both get trained.
                if np.random.rand() < 0.5:
                    self.q, self.q2 = self.q2, self.q
                batch_size = self.OFFLINE_BATCHSIZE if offline else self.online_batch_size
                if len(self.memory.history) < batch_size:
                    return
                states, actions, rewards, following_states, dones = self.memory.get_batch(
                    batch_size)
                # Targets come from the second network to reduce overestimation bias.
                q_max = rewards + self.gamma * np.amax(
                    self.q2.predict_on_batch(following_states), axis=1) * (1 - dones)
                y = self.q.predict_on_batch(states)
                idx = np.arange(batch_size)
                y[idx, actions] = q_max
                if offline:
                    history = self.q.net.fit(states, y, epochs=epochs, verbose=0)
                    loss = history.history['loss'][-1]
                else:
                    loss = self.q.fit(states, y, epochs)
                if self.epsilon > self.epsilon_min:
                    self.epsilon *= self.epsilon_decay
            # Only the loss of the final update is returned.
            return loss
    
    class CarlaManual(QAgent):
        """Agent that bypasses the Q-network and reads actions from a manual controller."""

        def __init__(self, conf):
            super().__init__(conf)
            self.control = Controller()

        def get_action(self, state):
            self.control.on_update()
            return self.control.get_action()
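
    # --- Usage sketch (illustrative, not part of the original module) ---
    # Shows the intended call pattern of the QAgent API above. The config object
    # is assumed to expose the attributes used in __init__ (a Gym-style conf.env,
    # conf.name, conf.eps_decay); make_config() below is a hypothetical helper.
    #
    # if __name__ == '__main__':
    #     conf = make_config()                      # hypothetical config factory
    #     agent = QAgent(conf)
    #     state = conf.env.reset()
    #     for _ in range(10_000):
    #         action = agent.get_action(np.reshape(state, (1, -1)))
    #         next_state, reward, done, _ = conf.env.step(action)
    #         agent.remember(state, action, reward, next_state, done)
    #         agent.learn()                         # online update once memory is filled
    #         state = conf.env.reset() if done else next_state
    #     agent.save('./checkpoints')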