CNNTrainer_cartpole_master_dqn.py 3.69 KB
Newer Older
1
from reinforcement_learning.agent import DqnAgent
Nicola Gatto's avatar
Nicola Gatto committed
2
from reinforcement_learning.util import AgentSignalHandler
Nicola Gatto's avatar
Nicola Gatto committed
3
from reinforcement_learning.cnnarch_logger import ArchLogger
4 5
import reinforcement_learning.environment
import CNNCreator_cartpole_master_dqn
Nicola Gatto's avatar
Nicola Gatto committed
6 7 8 9

import os
import sys
import re
Nicola Gatto's avatar
Nicola Gatto committed
10 11
import time
import numpy as np
12 13
import mxnet as mx

Nicola Gatto's avatar
Nicola Gatto committed
14

Nicola Gatto's avatar
Nicola Gatto committed
15
def resume_session(sessions_dir):
    """Offer to resume the most recent interrupted training session.

    Scans ``sessions_dir`` for entries whose name contains a timestamp of
    the form ``YYYY-MM-DD-HH-MM`` and walks them in descending name order.
    The first entry that contains a ``.interrupted_session`` directory is
    offered to the user on the console; no further entries are considered
    afterwards, whatever the answer.

    Args:
        sessions_dir: Directory holding one sub-directory per session.

    Returns:
        A ``(resume, resume_directory)`` tuple. ``resume`` is ``True`` only
        when the user answered ``'y'``; ``resume_directory`` is then the path
        of the ``.interrupted_session`` directory, otherwise ``None``.
    """
    if not os.path.isdir(sessions_dir):
        return False, None

    # ``raw_input`` exists only on Python 2. On Python 3 the plain ``input``
    # has the same read-a-line semantics, so fall back to it there.
    try:
        ask = raw_input
    except NameError:
        ask = input

    session_pattern = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
    # sorted() works on both Python 2 and 3; calling .sort() on the result of
    # filter() fails on Python 3, where filter() returns an iterator.
    session_names = sorted(
        (name for name in os.listdir(sessions_dir)
         if session_pattern.search(name)),
        reverse=True)

    for name in session_names:
        interrupted_session_dir = os.path.join(
            sessions_dir, name, '.interrupted_session')
        if not os.path.isdir(interrupted_session_dir):
            continue
        answer = ask(
            'Interrupted session from {} found. Do you want to resume? (y/n) '
            .format(name))
        if answer == 'y':
            return True, interrupted_session_dir
        # Declined: stop here instead of offering older sessions.
        return False, None

    return False, None

Nicola Gatto's avatar
Nicola Gatto committed
33

34
if __name__ == "__main__":
    # Script entry point: configure logging, build the environment and the
    # Q-network, create (or resume) a DqnAgent, train it, and export the
    # best network when training reports success.
    agent_name = 'cartpole_master_dqn'

    # Prepare output directory and logger. Every run gets its own
    # timestamped sub-directory below model/<agent_name>.
    all_output_dir = os.path.join('model', agent_name)
    session_stamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
    output_directory = os.path.join(all_output_dir, session_stamp)
    ArchLogger.set_output_directory(output_directory)
    ArchLogger.set_logger_name(agent_name)
    ArchLogger.set_output_level(ArchLogger.INFO)

    env = reinforcement_learning.environment.GymEnvironment('CartPole-v0')

    context = mx.cpu()
    qnet_creator = CNNCreator_cartpole_master_dqn.CNNCreator_cartpole_master_dqn()
    qnet_creator.construct(context)

    # Nested parameter groups, kept separate for readability.
    replay_memory_params = {
        'method': 'buffer',
        'memory_size': 10000,
        'sample_size': 32,
        'state_dtype': 'float32',
        'action_dtype': 'uint8',
        'rewards_dtype': 'float32',
    }
    strategy_params = {
        'method': 'epsgreedy',
        'epsilon': 1,
        'min_epsilon': 0.01,
        'epsilon_decay_method': 'linear',
        'epsilon_decay': 0.01,
    }
    optimizer_params = {
        'learning_rate': 0.001,
    }

    agent_params = {
        'environment': env,
        'replay_memory_params': replay_memory_params,
        'strategy_params': strategy_params,
        'agent_name': agent_name,
        'verbose': True,
        'output_directory': output_directory,
        'state_dim': (4,),
        'action_dim': (2,),
        'ctx': 'cpu',
        'discount_factor': 0.999,
        'training_episodes': 160,
        'train_interval': 1,
        'snapshot_interval': 20,
        'max_episode_step': 250,
        'target_score': 185.5,
        'qnet': qnet_creator.networks[0],
        'use_fix_target': True,
        'target_update_interval': 200,
        'loss_function': 'huber',
        'optimizer': 'rmsprop',
        'optimizer_params': optimizer_params,
        'double_dqn': False,
    }

    resume, resume_directory = resume_session(all_output_dir)

    if resume:
        # Log into the directory of the session being resumed, i.e. the
        # parent of its '.interrupted_session' directory.
        output_directory = os.path.dirname(resume_directory)
        ArchLogger.set_output_directory(output_directory)
        agent = DqnAgent.resume_from_session(
            session_dir=resume_directory,
            environment=env,
            net=qnet_creator.networks[0],
        )
    else:
        agent = DqnAgent(**agent_params)

    signal_handler = AgentSignalHandler()
    signal_handler.register_agent(agent)

    train_successful = agent.train()

    if train_successful:
        export_path = (qnet_creator._model_dir_
                       + qnet_creator._model_prefix_
                       + '_0_newest')
        agent.export_best_network(path=export_path, epoch=0)