Skip to content
Snippets Groups Projects
Select Git revision
  • 4764b70a146a1bdbc06b901287abe2f96a8b4323
  • master default protected
  • change_modified_reward_v0
  • feature_carla_szenarios
  • develop_moreSensorsInCarla
  • feature_carlaSupport
  • LearningEnvironment
7 results

main.py

Blame
  • Armin's avatar
    Armin Co authored
    4764b70a
    History
    main.py 1.79 KiB
    """
    Run your desired environment and agent configuration.
    """
    
    import os
    import atexit
    import gym
    
    from agents import QAgent as QAgent
    import environment_wrapper as ew
    
    def run(conf):
        """Train and/or evaluate a Q-learning agent as described by *conf*.

        Steps: optionally hide the GPUs, build the agent, optionally load a
        stored net/memory, optionally learn offline from stored memories, run
        the agent in ``conf.env`` (learning online if enabled), save the
        result and process the logs.

        Args:
            conf: Run configuration (an ``environment_wrapper.Config`` in this
                file's ``__main__``). Attributes read here: ``force_cpu``,
                ``load_ann``, ``load_mem``, ``save_to``, ``load_from``,
                ``learn``, ``learn_offline``, ``offline_epochs``,
                ``learn_online``, ``env``, ``run_episodes``, ``render`` and
                ``env_type``.

        When ``conf.env_type == 'Carla'``, ``conf.env.world.destroy()`` is
        called when the run ends, including when it ends with an exception.
        """
        # 0. Allow GPU usage or force tensorflow to use the CPU.
        #    An empty CUDA_VISIBLE_DEVICES hides all GPUs from CUDA.
        #    NOTE(review): this only takes effect if TensorFlow has not yet
        #    initialised its devices — confirm `agents` does not do so on import.
        if conf.force_cpu:
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
            os.environ["CUDA_VISIBLE_DEVICES"] = ""

        try:
            # 1. Create a learning agent.
            marvin = QAgent(conf)

            # 1.5 *Optional*: load agent memory and/or net from disk.
            if conf.load_ann or conf.load_mem:
                marvin.load(conf.save_to + conf.load_from + '/' + conf.load_from,
                            net=conf.load_ann, memory=conf.load_mem)

            # 2. Register an *atexit* callback to store the current state of
            #    the agent if the program is interrupted before the final save.
            if conf.learn:
                atexit.register(marvin.save, conf.save_to)

            # 3. Offline training of the agent with previously collected and
            #    saved memories.
            if conf.learn_offline and conf.learn:
                ew.learn_offline(marvin, epochs=conf.offline_epochs)

            # 4. Run the agent in the environment for the configured number of
            #    episodes, either to verify its performance or to train it.
            learn_online = conf.learn_online and conf.learn
            loss, avg_score = ew.run(conf.env, marvin, conf.run_episodes,
                                     render=conf.render, learn=learn_online)

            # 5. Save the final training result. The atexit fallback is no
            #    longer needed; leaving it registered would save the agent a
            #    second time when the interpreter exits.
            if conf.learn:
                marvin.save(conf.save_to)
                atexit.unregister(marvin.save)

            ew.process_logs(avg_score, loss, conf)
        finally:
            # Tear down the Carla world even if one of the steps above raised.
            if conf.env_type == 'Carla':
                conf.env.world.destroy()
    
    if __name__ == '__main__':
        # Default configuration for a LunarLander-v2 run.
        conf = ew.Config()
        conf.env = gym.make('LunarLander-v2')
        conf.env_type = 'Lunar'
        # Called after env/env_type are set; presumably derives the run name
        # from the configuration — confirm in environment_wrapper.Config.
        conf.conf_to_name()
        try:
            run(conf)
        finally:
            # Close the gym environment (e.g. its render window) once the run
            # is over, whether or not it succeeded.
            conf.env.close()