Commit 99dec5eb authored by Armin Co

Monitoring offline training

parent b39317dc
@@ -77,24 +77,32 @@ def one_episode(environment, agent, render, learn, conf=None, max_steps=1000):
 def learn_offline(agent, conf):
     """ Train the agent with its memories. """
     print('Learning with ', len(agent.memory.history), ' memories.')
+    agent.epsilon = agent.epsilon_min
+    score_history = []
+    avg_score_history = []
+    desc_train = ''
     pbar = trange(conf.offline_epochs, desc='Loss: x')
     for i in pbar:
         loss = agent.learn(offline=True, epochs=conf.learn_iterations)
-        desc = ('Loss: %05.4f' %(loss))
+        desc = ('Loss: %05.4f' %(loss)) + desc_train
         pbar.set_description(desc)
         pbar.refresh()
-        if i % conf.offline_validate_every_x_iteration == 0 and conf.offline_validate_every_x_iteration is not -1:
-            score, avg = run(conf.env, conf.agent, 1, render=conf.render, learn=False, conf=conf)
-            conf.name += '1'
-            process_logs(avg, score, conf)
-            if avg[-1] > IS_SOLVED:
+        if i % conf.offline_validate_every_x_iteration == 1 and conf.offline_validate_every_x_iteration != -1:
+            score = one_episode(conf.env, agent, conf.render, False, conf=conf)
+            score_history.append(score)
+            is_solved = np.mean(score_history[-25:])
+            desc_train = (', Avg: %05.1f' %(is_solved))
+            avg_score_history.append(is_solved)
+            if is_solved > IS_SOLVED:
                 break
+    process_logs(avg_score_history, score_history, conf)

 def run(environment, agent, episodes, render=True, learn=True, conf=None):
     """ Run an agent """
+    conf.name += 'on'
     # Set the exploring rate to its minimum.
     # (epsilon *greedy*)
     if not learn:
...
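In plain terms, this hunk turns offline training into a monitored loop: every conf.offline_validate_every_x_iteration epochs the agent plays one greedy validation episode, the score feeds a 25-episode moving average shown in the tqdm progress bar, and training stops early once that average clears IS_SOLVED. Below is a minimal, self-contained sketch of that monitoring pattern; the evaluate() stub, the threshold of 200, the interval of 10, and the epoch count are illustrative stand-ins, not values from this repository.

    import numpy as np
    from tqdm import trange

    IS_SOLVED = 200        # illustrative threshold, not the repository's value
    VALIDATE_EVERY = 10    # stand-in for conf.offline_validate_every_x_iteration

    def evaluate():
        """Stand-in for one_episode(): score of one greedy validation run."""
        return float(np.random.normal(loc=180.0, scale=40.0))

    score_history, avg_score_history = [], []
    desc_train = ''
    pbar = trange(500, desc='Loss: x')
    for i in pbar:
        loss = 1.0 / (i + 1)                      # stand-in for agent.learn(...)
        pbar.set_description('Loss: %05.4f' % loss + desc_train)
        if VALIDATE_EVERY != -1 and i % VALIDATE_EVERY == 1:
            score_history.append(evaluate())
            avg = np.mean(score_history[-25:])    # moving average of last 25 scores
            avg_score_history.append(avg)
            desc_train = ', Avg: %05.1f' % avg
            if avg > IS_SOLVED:                   # stop once the task counts as solved
                break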
@@ -5,7 +5,8 @@ from keras.layers import Dense
 from keras.optimizers import Adam
 from keras.activations import relu, linear
 from keras.regularizers import l2
+from keras.callbacks import EarlyStopping

 class QNet:
     learn_rate = 0.0005
@@ -34,7 +34,8 @@ class QNet:
     def predict(self, states): return self.net.predict_on_batch(states)

     def fit(self, X, Y, epochs=1, verbose=0):
-        history = self.net.fit(X, Y, epochs=epochs, verbose=verbose)
+        callback = EarlyStopping(monitor='loss', patience=3)
+        history = self.net.fit(X, Y, epochs=epochs, verbose=verbose, callbacks=[callback])
        return history.history['loss'][-1]

     def save(self, path):
...
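The second change wraps QNet.fit with Keras's EarlyStopping callback, so each offline learning call stops as soon as the training loss has not improved for three consecutive epochs instead of always running the full epoch budget. A minimal sketch of that callback in isolation, using a toy model and random data (model shapes and sizes here are illustrative, not the repository's network):

    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.callbacks import EarlyStopping

    # Toy data: 256 samples, 8 features in, 4 outputs (illustrative shapes).
    X = np.random.rand(256, 8)
    Y = np.random.rand(256, 4)

    model = Sequential([
        Dense(64, activation='relu', input_shape=(8,)),
        Dense(4, activation='linear'),
    ])
    model.compile(optimizer='adam', loss='mse')

    # Stop fitting once 'loss' has failed to improve for 3 epochs in a row.
    stopper = EarlyStopping(monitor='loss', patience=3)
    history = model.fit(X, Y, epochs=50, verbose=0, callbacks=[stopper])
    print('final loss:', history.history['loss'][-1])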