diff --git a/.gitignore b/.gitignore index 31adcd1eea6e6b910b2362ee6452e6fc672ee573..ef8e3670a7dbd178e9615e7cdffb4ef9479c3b90 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__ saved_agents benchmarks baselines +simple workspace.code-workspace test tech_demo.py diff --git a/agents.py b/agents.py index 12ff6fe6db4b3ec57985685637e07f946cac35ff..2e7cdbe395789852117113bb99436a844ea9df42 100644 --- a/agents.py +++ b/agents.py @@ -20,6 +20,7 @@ class QAgent: self.action_space = conf.env.action_space.n self.name = conf.name self.epsilon_decay = conf.eps_decay + self.OFFLINE_BATCHSIZE = conf.offline_batchsize def get_action(self, state): if np.random.rand() <= self.epsilon: @@ -80,7 +81,7 @@ class DQAgent(QAgent): def get_action(self, state): if np.random.rand() <= self.epsilon: return random.randrange(self.action_space) - action_values = (self.q.predict(state) + self.q2.predict(state)) / 2 + action_values = self.q.predict(state) return np.argmax(action_values[0]) def learn(self, offline=False, epochs=1): @@ -110,6 +111,15 @@ class DQAgent(QAgent): self.epsilon *= self.epsilon_decay return loss + def load(self, path, net=True, memory=True): + print('Load: ' + path) + if net: + print('Network') + self.q.load(path+'.net') + self.q2.load(path+'.net') + if memory: + self.memory.load(path+'.mem') + class CarlaManual(QAgent): control = None