# CNNTrainer_torcs_agent_torcsAgent_dqn.py
from reinforcement_learning.agent import DqnAgent
from reinforcement_learning.util import AgentSignalHandler
from reinforcement_learning.cnnarch_logger import ArchLogger
import reinforcement_learning.environment
import CNNCreator_torcs_agent_torcsAgent_dqn

import os
import sys
import re
import time
import numpy as np
import mxnet as mx

def resume_session(sessions_dir, input_fn=None):
    """Look for an interrupted training session and ask whether to resume it.

    Candidate sessions are the entries of ``sessions_dir`` whose names contain
    a ``YYYY-MM-DD-HH-MM`` timestamp. They are scanned in reverse name order;
    for directories named with this script's zero-padded
    ``%Y-%m-%d-%H-%M-%S`` format that means newest first. Only the first
    session containing an ``.interrupted_session`` directory is offered: the
    user is prompted once and the scan stops whatever the answer.

    Args:
        sessions_dir: Directory holding the timestamped session directories.
            If it does not exist there is nothing to resume.
        input_fn: Optional callable ``prompt -> str`` used to ask the user.
            Defaults to ``raw_input`` on Python 2 and ``input`` on Python 3.

    Returns:
        ``(True, path)`` where ``path`` is the ``.interrupted_session``
        directory if the user answered ``y``/``Y``; otherwise
        ``(False, None)``.
    """
    if input_fn is None:
        try:
            # Python 2: raw_input returns the line verbatim, whereas Python 2's
            # input() would evaluate it.
            input_fn = raw_input  # noqa: F821
        except NameError:
            input_fn = input  # Python 3

    if not os.path.isdir(sessions_dir):
        return False, None

    timestamp = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
    # sorted() rather than filter(...).sort(): on Python 3 filter() returns an
    # iterator, which has no sort() method.
    session_names = sorted(
        (name for name in os.listdir(sessions_dir) if timestamp.search(name)),
        reverse=True)

    for name in session_names:
        interrupted_session_dir = os.path.join(
            sessions_dir, name, '.interrupted_session')
        if os.path.isdir(interrupted_session_dir):
            answer = input_fn(
                'Interrupted session from {} found. Do you want to resume? (y/n) '.format(name))
            if answer.strip().lower() == 'y':
                return True, interrupted_session_dir
            # Only the most recent interrupted session is ever offered.
            break
    return False, None

if __name__ == "__main__":
    agent_name = 'torcs_agent_torcsAgent_dqn'

    # Prepare output directory and logger: each run gets a timestamped
    # directory under model/<agent_name>, the same parent directory that
    # resume_session() scans below.
    all_output_dir = os.path.join('model', agent_name)
    output_directory = os.path.join(
        all_output_dir,
        time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())))
    ArchLogger.set_output_directory(output_directory)
    ArchLogger.set_logger_name(agent_name)
    ArchLogger.set_output_level(ArchLogger.INFO)

    # Node and topic names handed to the ROS-backed environment.
    env_params = {
        'ros_node_name': 'torcs_agent_torcsAgent_dqnTrainerNode',
        'state_topic': 'preprocessor_state',
        'action_topic': 'postprocessor_action',
        'reset_topic': 'torcs_reset',
        # NOTE(review): 'prepocessor' is spelled differently from
        # 'preprocessor_state' above. Left unchanged because the publishing
        # side must use exactly the same topic name — confirm there first.
        'terminal_state_topic': 'prepocessor_is_terminal',
    }
    env = reinforcement_learning.environment.RosEnvironment(**env_params)

    # Build the Q-network on the CPU.
    context = mx.cpu()
    qnet_creator = CNNCreator_torcs_agent_torcsAgent_dqn.CNNCreator_torcs_agent_torcsAgent_dqn()
    qnet_creator.construct(context)

    # Hyperparameters for a fresh DqnAgent (unused when resuming).
    agent_params = {
        'environment': env,
        'replay_memory_params': {
            'method': 'buffer',
            'memory_size': 1000000,
            'sample_size': 32,
            'state_dtype': 'float32',
            'action_dtype': 'float32',
            'rewards_dtype': 'float32'
        },
        'strategy_params': {
            'method': 'epsgreedy',
            'epsilon': 1,
            'min_epsilon': 0.01,
            'epsilon_decay_method': 'linear',
            'epsilon_decay': 0.0001,
        },
        'agent_name': agent_name,
        'verbose': True,
        'output_directory': output_directory,
        'state_dim': (5,),
        'action_dim': (30,),
        'ctx': 'cpu',
        'discount_factor': 0.999,
        'training_episodes': 20000,
        'train_interval': 1,
        'snapshot_interval': 1000,
        'max_episode_step': 999999999,
        'qnet': qnet_creator.networks[0],
        'use_fix_target': True,
        'target_update_interval': 500,
        'loss_function': 'huber',
        'optimizer': 'rmsprop',
        'optimizer_params': {
            'learning_rate': 0.001,
        },
        'double_dqn': True,
    }

    # Offer to continue the most recent interrupted session, if one exists.
    resume, resume_directory = resume_session(all_output_dir)

    if resume:
        # resume_directory is <session>/.interrupted_session; switch logging
        # to that interrupted session's directory instead of the fresh one.
        output_directory = os.path.dirname(resume_directory)
        ArchLogger.set_output_directory(output_directory)
        resume_agent_params = {
            'session_dir': resume_directory,
            'environment': env,
            'net': qnet_creator.networks[0],
        }
        agent = DqnAgent.resume_from_session(**resume_agent_params)
    else:
        agent = DqnAgent(**agent_params)

    signal_handler = AgentSignalHandler()
    signal_handler.register_agent(agent)

    train_successful = agent.train()

    if train_successful:
        # NOTE(review): plain string concatenation assumes _model_dir_ ends
        # with a path separator — confirm in the generated CNNCreator class.
        agent.export_best_network(
            path=qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest',
            epoch=0)