# CNNTrainer_torcs_agent_torcsAgent_dqn.py
from reinforcement_learning.agent import DqnAgent
import reinforcement_learning.environment

import CNNCreator_torcs_agent_torcsAgent_dqn
import logging
import mxnet as mx

if __name__ == "__main__":
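    # ROS topics wiring this trainer to the TORCS pre-/postprocessor nodes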
    env_params = {
        'ros_node_name': 'torcs_agent_torcsAgent_dqnTrainerNode',
        'state_topic': 'preprocessor_state',
        'action_topic': 'postprocessor_action',
        'reset_topic': 'torcs_reset',
        'terminal_state_topic': 'preprocessor_is_terminal'
    }
    env = reinforcement_learning.environment.RosEnvironment(**env_params)
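    # Construct the network on the chosen device; mx.cpu() here, mx.gpu(0) would target the first GPU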
    context = mx.cpu()
    net_creator = CNNCreator_torcs_agent_torcsAgent_dqn.CNNCreator_torcs_agent_torcsAgent_dqn()
    net_creator.construct(context)

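    # Uniform replay buffer holding up to one million transitions; training samples minibatches of 32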
    replay_memory_params = {
        'method': 'buffer',
        'memory_size': 1000000,
        'sample_size': 32,
        'state_dtype': 'float32',
        'action_dtype': 'uint8',
        'rewards_dtype': 'float32'
    }

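    # Epsilon-greedy exploration: epsilon decays linearly from 1.0 down to a floor of 0.01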
    policy_params = {
        'method': 'epsgreedy',
        'epsilon': 1.0,
        'min_epsilon': 0.01,
        'epsilon_decay_method': 'linear',
        'epsilon_decay': 0.0001
    }

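    # DQN agent configuration: a fixed target network (refreshed every target_update_interval)
    # plus double DQN to reduce overestimation of Q-values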
    agent = DqnAgent(
        network=net_creator.net,
        environment=env,
        replay_memory_params=replay_memory_params,
        policy_params=policy_params,
        state_dim=net_creator.get_input_shapes()[0],
        ctx='cpu',
        discount_factor=0.999,
        loss_function='euclidean',
        optimizer='rmsprop',
        optimizer_params={
            'learning_rate': 0.001
        },
        training_episodes=20000,
        train_interval=1,
        use_fix_target=True,
        target_update_interval=500,
        double_dqn=True,
        snapshot_interval=1000,
        agent_name='torcs_agent_torcsAgent_dqn',
        max_episode_step=999999999,
        output_directory='model',
        verbose=True,
        live_plot=True,
        make_logfile=True,
        target_score=None
    )
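    # Run training, then persist the best network found under '<model_dir><model_prefix>_newest'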
    train_successful = agent.train()
    agent.save_best_network(net_creator._model_dir_ + net_creator._model_prefix_ + '_newest', epoch=0)