Skip to content
Snippets Groups Projects
Select Git revision
  • ba4d8e617d547e5dda474ebdacc29a338c674ead
  • master default protected
  • change_modified_reward_v0
  • feature_carla_szenarios
  • develop_moreSensorsInCarla
  • feature_carlaSupport
  • LearningEnvironment
7 results

main.py

Blame
  • Armin's avatar
    Armin Co authored
    Make sure you change the name of the agent, otherwise the old data may be overridden.
    ba4d8e61
    History
    main.py 1.52 KiB
    """
    Run your desired environment and agent configuration.
    """
    
    import os
    import atexit
    import gym
    from agents import QAgent
    
    import environment_wrapper as ew
    
    def run(conf):
        """
        Run one complete session for the agent and environment described by *conf*.

        Steps:
          1. Take the agent stored in ``conf.agent``.
          2. Optionally load its network and/or memory from disk.
          3. Register an exit-time fallback save, optionally train offline on
             previously saved memories, then run the agent in ``conf.env`` for
             ``conf.run_episodes`` episodes (learning online if enabled).
          4. Save the agent, process the logs and, for Carla, destroy the world.

        Parameters
        ----------
        conf :
            Run configuration (typically an ``ew.Config``). Attributes read:
            ``agent``, ``load_ann``, ``load_mem``, ``save_to``, ``load_from``,
            ``learn``, ``learn_offline``, ``learn_online``, ``env``,
            ``run_episodes``, ``render`` and ``env_type``.

        Returns
        -------
        None
        """

        # 1. Create a learning agent
        marvin = conf.agent

        # (2.) *optional* Load agent memory and/or net from disk.
        if conf.load_ann or conf.load_mem:
            # Saved data is addressed as <save_to><load_from>/<load_from>.
            marvin.load(conf.save_to + conf.load_from + '/' + conf.load_from,
                        net=conf.load_ann, memory=conf.load_mem)

        # 3. Set your configurations for the run.
        # Register an *atexit* callback to store the current result of the
        # agent if the program is interrupted before the final save below.
        atexit.register(marvin.save, conf.save_to)

        try:
            # Offline training of the agent with
            # previously collected and saved memories.
            if conf.learn_offline and conf.learn:
                ew.learn_offline(marvin, conf)

            # Run the agent in the environment for the number of specified
            # episodes, either to verify its performance or to train it.
            learn_online = conf.learn_online and conf.learn
            loss, avg_score = ew.run(conf.env, marvin, conf.run_episodes,
                                     render=conf.render, learn=learn_online,
                                     conf=conf)

            # Save the final training result of the agent.
            marvin.save(conf.save_to)
            # The final result is stored, so drop the interruption fallback.
            # Otherwise the agent is saved a second time at interpreter exit
            # and stays referenced by the atexit registry when run() is
            # called repeatedly in one process. (unregister compares with ==,
            # so a fresh bound method of the same agent matches.)
            atexit.unregister(marvin.save)

            ew.process_logs(avg_score, loss, conf)
        finally:
            # Tear down the Carla world even when one of the steps above
            # raises (including KeyboardInterrupt).
            if conf.env_type == 'Carla':
                conf.env.world.destroy()
    
    if __name__ == '__main__':
        # Script entry point: run a QAgent on LunarLander-v2 with rendering on.
        conf = ew.Config()
        conf.render = True
        lunar_env = gym.make('LunarLander-v2')
        conf.env = lunar_env
        conf.env_type = 'Lunar'
        # NOTE(review): conf_to_name() is called after the settings above and
        # before the agent is built; presumably it derives the run name from
        # them, and reusing a name may overwrite earlier saved data — confirm.
        conf.conf_to_name()
        q_agent = QAgent(conf)
        conf.agent = q_agent
        run(conf)