diff --git a/environment_wrapper.py b/environment_wrapper.py index cf39665bc9605f7738116bed1ff7bc18706d7efd..fb800698a1c5f4290bfa6537b9be146ffa611cf3 100644 --- a/environment_wrapper.py +++ b/environment_wrapper.py @@ -31,10 +31,6 @@ class Config: def conf_to_name(self): - # 0. Allow GPU usage or force tensorflow to use the CPU. - if self.force_cpu: - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - os.environ["CUDA_VISIBLE_DEVICES"] = "" self.name = self.env_type + '_' + self.name for layer in self.net_layout: self.name += '_' + str(layer) + '_' diff --git a/networks.py b/networks.py index 04924543b256bf0568df1f95e5afbbb045d8fd30..7e16615400a0456b362f61f00f0ad74c520ae908 100644 --- a/networks.py +++ b/networks.py @@ -1,3 +1,4 @@ +import os from keras import Sequential from keras.models import load_model from keras.layers import Dense @@ -10,6 +11,9 @@ class QNet: learn_rate = 0.0005 def __init__(self, conf): + if conf.force_cpu: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = "" self.net = None self.net = Sequential() self.compile_net(conf) diff --git a/run_scripts/benchmarks.py b/run_scripts/benchmarks.py index a18117c53aef1a4ebe21d219e604da2b83633fd5..28382e658f52b16ebc231f7e3e9c09cdc63a80f8 100644 --- a/run_scripts/benchmarks.py +++ b/run_scripts/benchmarks.py @@ -23,9 +23,9 @@ smallNet.net_layout = [128, 32] smallNet.conf_to_name() smallNetSlow = copy.deepcopy(c) -smallNetSlow.name = 'SmallNetSlow' +smallNetSlow.name = 'SmallNet' smallNetSlow.net_layout = [128, 32] -smallNetSlow.learn_rate = 0.0005 +smallNetSlow.eps_decay = 0.9996 smallNetSlow.conf_to_name() smallNetDeep = copy.deepcopy(c) @@ -92,6 +92,12 @@ verryLittleNetDeep.name = 'VerryLittleNetDeep' verryLittleNetDeep.net_layout = [64, 32, 32] verryLittleNetDeep.conf_to_name() +ThisHasToBeTooBig = copy.deepcopy(c) +ThisHasToBeTooBig.name = 'ThisHasToBeTooBig' +ThisHasToBeTooBig.force_cpu = False +ThisHasToBeTooBig.net_layout = [4096, 2048] +ThisHasToBeTooBig.conf_to_name() + lun = copy.deepcopy(c) lun.run_episodes = 500 lun.name = 'NormalLunarDoubleNotSoMoreLearn' @@ -114,6 +120,8 @@ configuration = smallNetSlow # configuration = deepNetSlowLearn # configuration = smallNetDeepSlowLearn # configuration = lun +configuration = ThisHasToBeTooBig + print(configuration.name) configuration.agent = QAgent(configuration) main.run(configuration)