Skip to content
Snippets Groups Projects
Commit de2747b7 authored by Armin Co's avatar Armin Co
Browse files

Fixed learning on double QAgent.

parent fade62f0
Branches
No related tags found
No related merge requests found
__pycache__
.vscode
saved_agents
benchmarks
workspace.code-workspace
tech_demo.py
*.png
......@@ -9,6 +9,9 @@ class QAgent:
epsilon_min = 0.01
epsilon_decay = 0.9999
online_batch_size = 64
action_space = 1
name = "Q"
OFFLINE_BATCHSIZE = 2048
def __init__(self, conf):#self, action_space, state_space, name):
self.q = QNet(conf)#conf.env.action_space.n, conf.env.observation_space.shape[0])
......@@ -32,7 +35,7 @@ class QAgent:
epochs = 1
if offline:
batch_size = 2048
batch_size = self.OFFLINE_BATCHSIZE
if len(self.memory.history) < batch_size:
return
......@@ -68,9 +71,9 @@ class QAgent:
self.memory.load(path+'.mem')
class DQAgent(QAgent):
def __init__(self, action_space, state_space, name):
super().__init__(action_space, state_space, name)
self.q2 = QNet(action_space, state_space)
def __init__(self, conf):
super().__init__(conf)
self.q2 = QNet(conf)
def get_action(self, state):
if np.random.rand() <= self.epsilon:
......@@ -79,7 +82,7 @@ class DQAgent(QAgent):
return np.argmax(action_values[0])
def learn(self, offline=False):
for _ in range(3):
for _ in range(2):
if np.random.rand() < 0.5:
temp = self.q
self.q = self.q2
......@@ -87,14 +90,16 @@ class DQAgent(QAgent):
batch_size = self.online_batch_size
epochs = 1
if offline:
batch_size = 4096
batch_size = self.OFFLINE_BATCHSIZE
if len(self.memory.history) < batch_size:
return
states, actions, rewards, following_states, dones = self.memory.get_batch(batch_size)
q_max_hat = rewards + self.gamma * (np.amax(self.q2.predict_on_batch(following_states), axis=1)) * (1-dones)
states, actions, rewards, following_states, dones = self.memory.get_batch(
batch_size)
qMax = rewards + self.gamma * \
(np.amax(self.q2.predict_on_batch(following_states), axis=1)) * (1-dones)
y = self.q.predict_on_batch(states)
idx = np.array([i for i in range(batch_size)])
y[[idx], [actions]] = q_max_hat
y[[idx], [actions]] = qMax
if offline:
history = self.q.net.fit(states, y, epochs=2, verbose=0)
loss = history.history['loss'][-1]
......
import main
import environment_wrapper as ew
import gym
from carla_environment import CarlaEnvironment
import copy
import threading
# Base configuration shared by every benchmark variant below.
c = ew.Config()
c.name = 'Base'
c.render = False
c.env = gym.make('LunarLander-v2')
c.env_type = 'Lunar'
c.net_layout = [256, 128]
c.eps_decay = 0.9996
c.learn_rate = 0.001
c.run_episodes = 300
c.save_to = 'benchmarks/'


def _variant(name, net_layout, **overrides):
    """Clone the base config ``c``, set the variant's name and network
    topology, apply any extra attribute overrides (e.g. eps_decay,
    learn_rate), and finalize it via ``conf_to_name()``.

    Returns the new, independent Config instance.
    """
    conf = copy.deepcopy(c)
    conf.name = name
    conf.net_layout = net_layout
    for attr, value in overrides.items():
        setattr(conf, attr, value)
    conf.conf_to_name()
    return conf


# One configuration per benchmark run.  Names are kept byte-identical to
# the original script (including the 'Depp'/'Verry' spellings) so that
# previously saved benchmark results stay comparable.
smallNet = _variant('SmallNet', [128, 32])
smallNetDeep = _variant('SmallNetDepp', [128, 32, 32])
normalNet = _variant('NormalNet', [256, 128])
normalSlowDecay = _variant('NormalSlowDecay', [256, 128], eps_decay=0.99995)
normalSlowLearn = _variant('NormalSlowLearn', [256, 128], learn_rate=0.0005)
largeNet = _variant('LargeNet', [512, 256])
deepNet = _variant('DeppNet', [256, 128, 128])
littleNet = _variant('LittleNet', [64, 64])
verryLittleNet = _variant('VerryLittleNet', [64, 32])
verryLittleNetDeep = _variant('VerryLittleNetDeep', [64, 32, 32])

configurations = [smallNet, smallNetDeep, normalNet, normalSlowDecay,
                  normalSlowLearn, largeNet, deepNet, verryLittleNet,
                  littleNet, verryLittleNetDeep]

# Run all benchmarks concurrently, one thread per configuration.
# BUG FIX: threading.Thread expects `args` to be an argument *tuple*;
# the original passed the Config object itself (`args=conf`), which makes
# Thread attempt `main.run(*conf)` and fail (or call with wrong arguments).
# Each conf must be wrapped in a 1-tuple.
threads = [threading.Thread(target=main.run, args=(conf,))
           for conf in configurations]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment