Commit e09a4dd0 authored by Armin Co

Refactoring

parent ab185eb3
__pycache__
saved_agents
*.png
\ No newline at end of file
@@ -24,18 +24,18 @@ def one_episode(environment, agent, render, learn, max_steps=3000):
     score = 0
     state = reset(environment)
     for _ in range(max_steps):
-        action = agent.get_action(state)
         if render:
            environment.render()
+        action = agent.get_action(state)
         following_state, reward, done, _ = step(environment, action)
         agent.remember(state, action, reward, following_state, done)
         score += reward
         state = following_state
         if learn:
-            loss = agent.learn()
+            agent.learn()
         if done:
             if learn:
-                print(loss)
+                agent.learn(offline=True)
             break
     return score
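Note: the refactored episode loop keeps the action selection next to the environment step, trains online with agent.learn() after every step, and replaces the per-step loss print with a single offline pass, agent.learn(offline=True), once the episode ends. A minimal sketch of an agent exposing that interface, assuming a plain replay buffer and a placeholder policy (the real QAgent internals are not part of this commit):

import random
from collections import deque

class SketchAgent:
    """Illustrative stand-in for the interface one_episode() relies on."""

    def __init__(self, action_space, state_space, name):
        self.name = name
        self.action_space = action_space
        self.memory = deque(maxlen=100_000)  # replay buffer

    def get_action(self, state):
        # Placeholder policy; the real agent would query an epsilon-greedy Q-network.
        return random.randrange(self.action_space)

    def remember(self, state, action, reward, following_state, done):
        self.memory.append((state, action, reward, following_state, done))

    def learn(self, offline=False):
        # Online: one small mini-batch per step. Offline: a larger replay sweep
        # at the end of the episode. Only the batching is sketched here.
        if not self.memory:
            return 0.0
        batch_size = 1024 if offline else 64
        batch = random.sample(list(self.memory), min(batch_size, len(self.memory)))
        return float(len(batch))  # the real agent would return a training loss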
@@ -60,16 +60,19 @@ def run(environment, agent, episodes, render=True, learn=True):
     score_history = []
     avg_score_history = []
-    for _ in tqdm(range(episodes), desc="Training", unit="Episode"):
+    pbar = trange(episodes, desc='Score [actual, average]: [0, 0]', unit="Episodes")
+    for _ in pbar:
         score = one_episode(environment, agent, render, learn)
         score_history.append(score)
-        is_solved = np.mean(score_history[-100:])
+        is_solved = np.mean(score_history[-50:])
         avg_score_history.append(is_solved)
-        if is_solved > 235 and learn:
+        if is_solved > 200 and learn:
             break
-        print("Score [actual, average]: [{0:.2f}, {1:.2f}]\n".format(score, is_solved))
+        desc = ("Score [actual, average]: [{0:.2f}, {1:.2f}]".format(score, is_solved))
+        pbar.set_description(desc)
+        pbar.refresh()
     return score_history, avg_score_history
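Note: the plain print of the running score is replaced by a tqdm progress bar whose description carries the current and averaged score; this assumes trange is imported from tqdm where the previous tqdm call was (the import change itself is not shown in this diff). A standalone sketch of the pattern with dummy numbers:

from tqdm import trange
import time

pbar = trange(5, desc='Score [actual, average]: [0, 0]', unit="Episodes")
for episode in pbar:
    time.sleep(0.1)  # stand-in for one_episode(...)
    score, average = episode * 10.0, episode * 5.0  # dummy values
    pbar.set_description(
        "Score [actual, average]: [{0:.2f}, {1:.2f}]".format(score, average))
    pbar.refresh()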
@@ -6,7 +6,7 @@ import os
 import atexit
 import gym
-from agents import DQAgent
+from agents import QAgent
 import environment_wrapper as ew
 # Allow GPU usage or force tensorflow to use the CPU.
@@ -21,18 +21,20 @@ if __name__ == '__main__':
     env = gym.make('LunarLander-v2')
     # 2. Create a learning agent
-    marvin = DQAgent(env.action_space.n, env.observation_space.shape[0], 'double_test')
+    marvin = QAgent(env.action_space.n, env.observation_space.shape[0], 'from_scratch')
     # (2.5) *optional* Load agent memory and/or net from disk.
-    marvin.load('saved_agents/large_memory/large_memory', net=False)
+    LOAD_MEMORIES = False
+    LOAD_ANN = False
+    marvin.load('saved_agents/agent/agent', net=LOAD_ANN, memory=LOAD_MEMORIES)
     # 3. Set your configurations for the run.
-    RENDER = True
+    RENDER = False
     LEARNING = True
-    LEARN_ONLINE = False
-    LEARN_OFFLINE = True
-    RUN_EPISODES = 200
-    LEARN_OFFLINE_EPOCHS = 100
+    LEARN_ONLINE = True
+    LEARN_OFFLINE = False
+    RUN_EPISODES = 500
+    LEARN_OFFLINE_EPOCHS = 500
     SAVE_PATH = "./saved_agents"
     # Register an *atexit* callback,
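Note: loading and learning are now controlled entirely by module-level flags (LOAD_ANN, LOAD_MEMORIES, LEARN_ONLINE, LEARN_OFFLINE, RUN_EPISODES, LEARN_OFFLINE_EPOCHS), and the truncated comment above refers to the atexit hook that saves the agent when the script exits. A minimal, self-contained sketch of that hook; the callback name and its behaviour are assumptions, since the real registration is not shown in this hunk:

import atexit

SAVE_PATH = "./saved_agents"

def save_on_exit(agent_name, path):
    # Hypothetical stand-in: the real script would call something like
    # marvin.save(path) here so training progress survives a Ctrl+C.
    print("Saving agent '{}' to {}".format(agent_name, path))

# Register the callback once; it runs when the interpreter shuts down,
# whether the run finishes normally or is interrupted.
atexit.register(save_on_exit, "from_scratch", SAVE_PATH)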
@@ -59,4 +61,4 @@ if __name__ == '__main__':
     # Show the result of the run.
     if RENDER:
-        ew.process_logs(avg_score, loss)
+        ew.process_logs(avg_score, loss, title=marvin.name)
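Note: process_logs now receives the agent's name as a plot title. Its implementation is not part of this commit; a rough matplotlib equivalent, assuming it simply plots the averaged score and loss histories and writes a PNG (which would fit the new *.png entry in .gitignore), might look like this:

import matplotlib.pyplot as plt

def process_logs_sketch(avg_score, loss, title="agent"):
    # Hypothetical re-creation of ew.process_logs for illustration only.
    fig, (ax_score, ax_loss) = plt.subplots(2, 1, sharex=True)
    ax_score.plot(avg_score)
    ax_score.set_ylabel("average score")
    ax_loss.plot(loss)
    ax_loss.set_ylabel("loss")
    ax_loss.set_xlabel("episode")
    fig.suptitle(title)
    fig.savefig(title + ".png")
    plt.show()

process_logs_sketch([10.0, 50.0, 120.0], [1.2, 0.8, 0.5], title="from_scratch")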