From 99dec5eb6ddc17e5cdbab2661fa9d22c02b159b6 Mon Sep 17 00:00:00 2001
From: Armin <armin.co@hs-bochum.de>
Date: Mon, 8 Mar 2021 09:52:57 +0100
Subject: [PATCH] Monitoring offline training

---
 environment_wrapper.py | 22 +++++++++++++++-------
 networks.py            |  7 +++++--
 2 files changed, 20 insertions(+), 9 deletions(-)
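
Note (not part of the patch, illustration only): a minimal standalone sketch of the
rolling-average check that learn_offline() now performs during offline training.
It assumes numpy is available; the IS_SOLVED value and the record() helper below are
made up for the example (the real threshold lives in environment_wrapper.py).

    import numpy as np

    IS_SOLVED = 200.0        # assumed threshold, not the project's actual value
    score_history = []

    def record(score):
        """Append a validation episode score and return the mean of the last 25."""
        score_history.append(score)
        return np.mean(score_history[-25:])

    for episode_score in (150.0, 190.0, 240.0, 260.0):
        avg_score = record(episode_score)
        if avg_score > IS_SOLVED:
            break            # offline training would stop here, as in the patched loop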

diff --git a/environment_wrapper.py b/environment_wrapper.py
index fb80069..e55058b 100644
--- a/environment_wrapper.py
+++ b/environment_wrapper.py
@@ -77,24 +77,32 @@ def one_episode(environment, agent, render, learn, conf=None, max_steps=1000):
 def learn_offline(agent, conf):
     """ Train the agent with its memories. """
     print('Learning with ', len(agent.memory.history), ' memories.')
+    agent.epsilon = agent.epsilon_min
+
+    score_history = []
+    avg_score_history = []
+    desc_train = ''
     pbar = trange(conf.offline_epochs, desc='Loss: x')
     for i in pbar:
         loss = agent.learn(offline=True, epochs=conf.learn_iterations)
-        desc = ('Loss: %05.4f' %(loss))
+        desc = ('Loss: %05.4f' %(loss)) + desc_train
         pbar.set_description(desc)
         pbar.refresh()
-        if i % conf.offline_validate_every_x_iteration == 0 and conf.offline_validate_every_x_iteration is not -1:
-            score, avg = run(conf.env, conf.agent, 1, render=conf.render, learn=False, conf=conf)
-            conf.name += '1'
-            process_logs(avg, score, conf)
-            if avg[-1] > IS_SOLVED:
+        if conf.offline_validate_every_x_iteration != -1 and i % conf.offline_validate_every_x_iteration == 1:
+            score = one_episode(conf.env, agent, conf.render, False, conf=conf)
+            score_history.append(score)
+            avg_score = np.mean(score_history[-25:])
+            desc_train = (', Avg: %05.1f' %(avg_score))
+            avg_score_history.append(avg_score)
+            if avg_score > IS_SOLVED:
                 break
+    process_logs(avg_score_history, score_history, conf)
 
 
 
 def run(environment, agent, episodes, render=True, learn=True, conf=None):
     """ Run an agent """
-
+    conf.name += 'on'
     # Set the exploring rate to its minimum.
     # (epsilon *greedy*)
     if not learn:
diff --git a/networks.py b/networks.py
index 7e16615..32b9461 100644
--- a/networks.py
+++ b/networks.py
@@ -5,7 +5,8 @@ from keras.layers import Dense
 from keras.optimizers import Adam
 from keras.activations import relu, linear
 from keras.regularizers import l2
-
+from keras.callbacks import EarlyStopping
+
 class QNet:
     
     learn_rate = 0.0005
@@ -34,7 +35,9 @@ class QNet:
         self, states): return self.net.predict_on_batch(states)
 
     def fit(self, X, Y, epochs=1, verbose=0):
-        history = self.net.fit(X, Y, epochs=epochs, verbose=verbose)
+        # Stop fitting early once the training loss has not improved for 3 epochs.
+        callback = EarlyStopping(monitor='loss', patience=3)
+        history = self.net.fit(X, Y, epochs=epochs, verbose=verbose, callbacks=[callback])
         return history.history['loss'][-1]
 
     def save(self, path):
-- 
GitLab
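
Note (not part of the patch): a small self-contained sketch of how the EarlyStopping
callback now wired into QNet.fit() behaves, using a toy model and random data. The
layer sizes, data shapes, and epoch count below are made up for illustration.

    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.callbacks import EarlyStopping

    model = Sequential([Dense(8, activation='relu', input_shape=(4,)), Dense(1)])
    model.compile(optimizer='adam', loss='mse')

    X = np.random.rand(64, 4)
    Y = np.random.rand(64, 1)

    # Stops fitting once the training loss has not improved for 3 consecutive
    # epochs, mirroring EarlyStopping(monitor='loss', patience=3) in networks.py.
    callback = EarlyStopping(monitor='loss', patience=3)
    history = model.fit(X, Y, epochs=50, verbose=0, callbacks=[callback])
    print(history.history['loss'][-1])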