Select Git revision
chardev-1.c
Forked from
Peter Gerwinski / bs
Source project has a limited visibility.
memory.py 1.14 KiB
import random
import pickle
import numpy as np
from collections import deque
# Positions of the fields inside each transition tuple that Memory.add()
# stores: (state, action, reward, nextstate, done).
STATE, ACTION, REWARD, NEXT_STATE, DONE = range(5)
class Memory:
    """Replay buffer that stores an agent's transitions and samples batches.

    Each stored transition is the tuple
    ``(state, action, reward, nextstate, done)``.
    """

    def __init__(self, maxlen=1000000):
        """Create an empty buffer.

        Args:
            maxlen: Maximum number of transitions kept. Once the buffer is
                full, adding a transition discards the oldest one (``deque``
                semantics).
        """
        # Per-instance storage: a class-level deque would be shared by every
        # Memory object, silently mixing the experience of separate buffers.
        self.history = deque(maxlen=maxlen)

    def add(self, state, action, reward, nextstate, done):
        """Append one transition to the buffer."""
        self.history.append((state, action, reward, nextstate, done))

    @staticmethod
    def _squeeze_samples(samples):
        """Drop singleton axes inside each sample while keeping the batch axis."""
        # A plain np.squeeze would also remove axis 0 when the batch holds a
        # single sample, handing callers an unbatched array.
        axes = tuple(
            axis for axis in range(1, samples.ndim) if samples.shape[axis] == 1
        )
        return np.squeeze(samples, axis=axes) if axes else samples

    def get_batch(self, batch_size):
        """Sample ``batch_size`` transitions uniformly without replacement.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            Tuple ``(states, actions, rewards, nextstates, dones)`` of NumPy
            arrays whose first axis indexes the sampled transitions. Singleton
            axes inside each state / next state are squeezed away; the batch
            axis is always kept.

        Raises:
            ValueError: if ``batch_size`` is negative or larger than the
                number of stored transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.history, batch_size)
        # Transpose the list of transitions into one column per field. An
        # empty batch has nothing to transpose, so fall back to five empty
        # columns (each becomes an empty array, as before).
        columns = list(zip(*batch)) if batch else [()] * 5
        states, actions, rewards, nextstates, dones = (
            np.array(column) for column in columns
        )
        return (
            self._squeeze_samples(states),
            actions,
            rewards,
            self._squeeze_samples(nextstates),
            dones,
        )

    def save(self, path):
        """Pickle the stored transitions to the file at ``path``."""
        with open(path, 'wb') as f:
            pickle.dump(self.history, f)

    def load(self, path):
        """Replace the stored transitions with those pickled at ``path``.

        The unpickled object is assigned to ``self.history`` as-is; a pickled
        deque carries its own ``maxlen``.
        """
        # SECURITY: pickle.load can execute arbitrary code. Only load files
        # written by this program or obtained from a trusted source.
        with open(path, 'rb') as f:
            self.history = pickle.load(f)
        print('Loaded {} memories.'.format(len(self.history)))