Skip to content
Snippets Groups Projects
Select Git revision
  • 2023ss
  • 2025ss default
  • 2024ss
  • 2022ss
  • 2021ss
  • 2020ss
  • 2019ss
  • 2018ss
8 results

shift-01.c

Blame
  • memory.py 1.14 KiB
    import random
    import pickle
    import numpy as np
    from collections import deque
    
    # Positions of the fields inside each transition tuple, in the order
    # (state, action, reward, nextstate, done) used by Memory.add.
    STATE, ACTION, REWARD, NEXT_STATE, DONE = range(5)
    
    class Memory:
        """Bounded store for the transitions experienced by an agent.

        Each entry is a ``(state, action, reward, nextstate, done)`` tuple kept
        in a ``collections.deque``. Once ``maxlen`` entries are stored, adding
        another one discards the oldest entry.
        """

        def __init__(self, maxlen=1000000):
            """Create an empty memory.

            Args:
                maxlen: Maximum number of transitions kept in the buffer.
            """
            # The buffer is created per instance on purpose: a class-level deque
            # would be shared (and mutated) by every Memory object.
            self.history = deque(maxlen=maxlen)

        def add(self, state, action, reward, nextstate, done):
            """Append one transition to the end of the buffer."""
            self.history.append((state, action, reward, nextstate, done))

        def get_batch(self, batch_size):
            """Draw ``batch_size`` distinct transitions at random.

            Args:
                batch_size: Number of transitions to sample.

            Returns:
                A tuple ``(states, actions, rewards, nextstates, dones)`` of
                NumPy arrays, one element per sampled transition.

            Raises:
                ValueError: Raised by ``random.sample`` when ``batch_size`` is
                    larger than the number of stored transitions.
            """
            batch = random.sample(self.history, batch_size)
            # np.squeeze drops every length-1 axis of the stacked states; note
            # that with batch_size == 1 this also removes the batch axis.
            states = np.squeeze(np.array([item[STATE] for item in batch]))
            actions = np.array([item[ACTION] for item in batch])
            rewards = np.array([item[REWARD] for item in batch])
            nextstates = np.squeeze(
                np.array([item[NEXT_STATE] for item in batch]))
            dones = np.array([item[DONE] for item in batch])
            return states, actions, rewards, nextstates, dones

        def save(self, path):
            """Pickle the buffer to the file at ``path``."""
            # The context manager guarantees the file is flushed and closed.
            with open(path, 'wb') as handle:
                pickle.dump(self.history, handle)

        def load(self, path):
            """Replace the buffer with the one pickled in the file at ``path``.

            The loaded object is used as-is, so it keeps whatever ``maxlen`` it
            was saved with.
            """
            # SECURITY: unpickling can execute arbitrary code. Only load files
            # that come from a trusted source.
            with open(path, 'rb') as handle:
                self.history = pickle.load(handle)
            print('Loaded ' + str(len(self.history)) + ' memories.')