From 8e575e06c41605a4f1c21d44921ee33112fc9d2b Mon Sep 17 00:00:00 2001 From: Armin <armin.co@hs-bochum.de> Date: Fri, 12 Mar 2021 22:06:15 +0100 Subject: [PATCH] Added offline batchsize option Load both networks for double q agent --- .gitignore | 1 + agents.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 31adcd1..ef8e367 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__ saved_agents benchmarks baselines +simple workspace.code-workspace test tech_demo.py diff --git a/agents.py b/agents.py index 12ff6fe..2e7cdbe 100644 --- a/agents.py +++ b/agents.py @@ -20,6 +20,7 @@ class QAgent: self.action_space = conf.env.action_space.n self.name = conf.name self.epsilon_decay = conf.eps_decay + self.OFFLINE_BATCHSIZE = conf.offline_batchsize def get_action(self, state): if np.random.rand() <= self.epsilon: @@ -80,7 +81,7 @@ class DQAgent(QAgent): def get_action(self, state): if np.random.rand() <= self.epsilon: return random.randrange(self.action_space) - action_values = (self.q.predict(state) + self.q2.predict(state)) / 2 + action_values = self.q.predict(state) return np.argmax(action_values[0]) def learn(self, offline=False, epochs=1): @@ -110,6 +111,15 @@ class DQAgent(QAgent): self.epsilon *= self.epsilon_decay return loss + def load(self, path, net=True, memory=True): + print('Load: ' + path) + if net: + print('Network') + self.q.load(path+'.net') + self.q2.load(path+'.net') + if memory: + self.memory.load(path+'.mem') + class CarlaManual(QAgent): control = None -- GitLab