From 8e575e06c41605a4f1c21d44921ee33112fc9d2b Mon Sep 17 00:00:00 2001
From: Armin <armin.co@hs-bochum.de>
Date: Fri, 12 Mar 2021 22:06:15 +0100
Subject: [PATCH] Added offline batchsize option

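Load the saved network file into both Q-networks (q and q2) of the
double-Q agent. Action selection now uses q alone instead of the
average of q and q2.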
---
 .gitignore |  1 +
 agents.py  | 12 +++++++++++-
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 31adcd1..ef8e367 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,7 @@ __pycache__
 saved_agents
 benchmarks
 baselines
+simple
 workspace.code-workspace
 test
 tech_demo.py
diff --git a/agents.py b/agents.py
index 12ff6fe..2e7cdbe 100644
--- a/agents.py
+++ b/agents.py
@@ -20,6 +20,7 @@ class QAgent:
         self.action_space = conf.env.action_space.n
         self.name = conf.name
         self.epsilon_decay = conf.eps_decay
+        self.OFFLINE_BATCHSIZE = conf.offline_batchsize
 
     def get_action(self, state):
         if np.random.rand() <= self.epsilon:
@@ -80,7 +81,7 @@ class DQAgent(QAgent):
     def get_action(self, state):
         if np.random.rand() <= self.epsilon:
             return random.randrange(self.action_space)
-        action_values = (self.q.predict(state) + self.q2.predict(state)) / 2
+        action_values = self.q.predict(state)
         return np.argmax(action_values[0])
 
     def learn(self, offline=False, epochs=1):
@@ -110,6 +111,15 @@ class DQAgent(QAgent):
                 self.epsilon *= self.epsilon_decay
         return loss
 
+    def load(self, path, net=True, memory=True):
+        print('Load:  ' + path)
+        if net:
+            print('Network')
+            self.q.load(path+'.net')
+            self.q2.load(path+'.net')
+        if memory:
+            self.memory.load(path+'.mem')
+
 class CarlaManual(QAgent):
     control = None
 
-- 
GitLab