Skip to content
Snippets Groups Projects
Commit 662c1621 authored by Armin Co
Browse files

Update to scriptability.

parent de2747b7
Branches
Tags
No related merge requests found
""" Wrapper to abstract different learning environments for an agent. """
import os
import numpy as np
from tqdm import trange
import pandas as pd
import matplotlib.pyplot as plt
......@@ -9,6 +9,7 @@ class Config:
render = False
force_cpu = True
env = None
agent = None
env_type = 'Lunar'
name = 'ConfigTest'
learn = True
......@@ -24,7 +25,12 @@ class Config:
load_from = 'agnt'
save_to = 'saved_agents/'
def conf_to_name(self):
# 0. Allow GPU usage or force tensorflow to use the CPU.
if self.force_cpu:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
self.name += self.env_type
for layer in self.net_layout:
self.name += '_' + str(layer) + '_'
......@@ -85,15 +91,15 @@ def run(environment, agent, episodes, render=True, learn=True):
score_history = []
avg_score_history = []
pbar = trange(episodes, desc='Score [actual, average]: [0, 0]', unit="Episodes")
pbar = trange(episodes, desc=agent.name + ' [actual, average]: [0, 0]', unit="Episodes")
for _ in pbar:
score = one_episode(environment, agent, render, learn)
score_history.append(score)
is_solved = np.mean(score_history[-50:])
is_solved = np.mean(score_history[-100:])
avg_score_history.append(is_solved)
if is_solved > 200 and learn:
if is_solved > 195 and learn:
break
desc = ("Score [actual, average]: [{0:.2f}, {1:.2f}]".format(score, is_solved))
pbar.set_description(desc)
......
......@@ -6,19 +6,14 @@ import os
import atexit
import gym
from agents import QAgent as QAgent
import environment_wrapper as ew
def run(conf):
# 0. Allow GPU usage or force tensorflow to use the CPU.
if conf.force_cpu:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# 2. Create a learning agent
marvin = QAgent(conf) #conf.env.action_space.n, conf.env.observation_space.shape[0], conf.name)
# 1. Create a learning agent
marvin = conf.agent
# (2.5) *optional* Load agent memory and/or net from disk.
# (2.) *optional* Load agent memory and/or net from disk.
if conf.load_ann or conf.load_mem:
marvin.load(conf.save_to + conf.load_from + '/' + conf.load_from, net=conf.load_ann, memory=conf.load_mem)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment