CNNTrainer_tD3Config.py
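# Trainer script for the 'tD3Config' agent: constructs the actor and critic
# networks, configures a Twin Delayed DDPG (TD3) agent, and trains it on the
# configured Gym environment.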
from reinforcement_learning.agent import TwinDelayedDdpgAgent
from reinforcement_learning.util import AgentSignalHandler
from reinforcement_learning.cnnarch_logger import ArchLogger
from reinforcement_learning.CNNCreator_CriticNetwork import CNNCreator_CriticNetwork
import reinforcement_learning.environment
import CNNCreator_tD3Config

import os
import sys
import re
import time
import numpy as np
import mxnet as mx


def resume_session(sessions_dir):
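    """Search previous training sessions (newest first) for an
    '.interrupted_session' directory and ask the user whether to resume it.

    Returns a tuple (resume_session, resume_directory)."""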
    resume_session = False
    resume_directory = None
    if os.path.isdir(sessions_dir):
        regex = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
        dir_content = os.listdir(sessions_dir)
        session_files = sorted(filter(regex.search, dir_content), reverse=True)
        for d in session_files:
            interrupted_session_dir = os.path.join(sessions_dir, d, '.interrupted_session')
            if os.path.isdir(interrupted_session_dir):
                resume = raw_input('Interrupted session from {} found. Do you want to resume? (y/n) '.format(d))
                if resume == 'y':
                    resume_session = True
                    resume_directory = interrupted_session_dir
                break
    return resume_session, resume_directory


if __name__ == "__main__":
    agent_name = 'tD3Config'
    # Prepare output directory and logger
    all_output_dir = os.path.join('model', agent_name)
    output_directory = os.path.join(
        all_output_dir,
        time.strftime('%Y-%m-%d-%H-%M-%S',
                      time.localtime(time.time())))
    ArchLogger.set_output_directory(output_directory)
    ArchLogger.set_logger_name(agent_name)
    ArchLogger.set_output_level(ArchLogger.INFO)

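    # Gym environment the agent is trained on.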
    env = reinforcement_learning.environment.GymEnvironment('CartPole-v1')

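    # Build the actor (policy) and critic (Q-value) networks on the CPU,
    # both initialized with normally distributed weights.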
    context = mx.cpu()
    initializer = mx.init.Normal()
    critic_initializer = mx.init.Normal()
    actor_creator = CNNCreator_tD3Config.CNNCreator_tD3Config()
    actor_creator.setWeightInitializer(initializer)
    actor_creator.construct(context)
    critic_creator = CNNCreator_CriticNetwork()
    critic_creator.setWeightInitializer(critic_initializer)
    critic_creator.construct(context)

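    # Configuration of the TD3 (Twin Delayed DDPG) agent: an 8-dimensional
    # state space, a 3-dimensional continuous action space, and an 'online'
    # replay memory holding float32 states, actions and rewards.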
    agent_params = {
        'environment': env,
        'replay_memory_params': {
            'method': 'online',
            'state_dtype': 'float32',
            'action_dtype': 'float32',
            'rewards_dtype': 'float32'
        },
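        # Exploration: Gaussian action noise (variance 0.3) governed by an
        # epsilon that presumably decays linearly by 0.0001 per step from 1.0
        # down to 0.001, starting after step 50 (roughly 10,000 steps until
        # the minimum); exploratory actions are kept within [-1, 1].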
        'strategy_params': {
            'method':'gaussian',
            'epsilon': 1,
            'min_epsilon': 0.001,
            'epsilon_decay_method': 'linear',
            'epsilon_decay': 0.0001,
            'epsilon_decay_start': 50,
            'epsilon_decay_per_step': True,
            'noise_variance': 0.3,
            'action_low': -1,
            'action_high': 1,
        },
        'agent_name': agent_name,
        'verbose': True,
        'output_directory': output_directory,
        'state_dim': (8,),
        'action_dim': (3,),
        'actor': actor_creator.networks[0],
        'critic': critic_creator.networks[0],
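        # Target networks track the trained networks via soft updates:
        # target = (1 - tau) * target + tau * online, with tau = 0.001.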
        'soft_target_update_rate': 0.001,
        'actor_optimizer': 'adam',
        'actor_optimizer_params': {
            'learning_rate_minimum': 5.0E-5,
            'learning_rate_policy': 'step',
            'learning_rate': 1.0E-4,
            'learning_rate_decay': 0.9},
        'critic_optimizer': 'rmsprop',
        'critic_optimizer_params': {
            'learning_rate_minimum': 1.0E-4,
            'learning_rate_policy': 'step',
            'learning_rate': 0.001,
            'learning_rate_decay': 0.5},
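        # TD3-specific settings: target policy smoothing adds noise (0.1,
        # clipped to +/-0.8) to target actions, and the actor is only updated
        # every 4th critic update (delayed policy updates).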
        'policy_noise': 0.1,
        'noise_clip': 0.8,
        'policy_delay': 4,
    }

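    # Offer to resume an interrupted session; otherwise start a fresh agent
    # with the configuration above.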
    resume, resume_directory = resume_session(all_output_dir)

    if resume:
        output_directory, _ = os.path.split(resume_directory)
        ArchLogger.set_output_directory(output_directory)
        resume_agent_params = {
            'session_dir': resume_directory,
            'environment': env,
            'actor': actor_creator.networks[0],
            'critic': critic_creator.networks[0]
        }
        agent = TwinDelayedDdpgAgent.resume_from_session(**resume_agent_params)
    else:
        agent = TwinDelayedDdpgAgent(**agent_params)

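    # Register the agent with the signal handler so an interrupt (e.g. SIGINT)
    # is handled gracefully, presumably by saving the current session state so
    # that it can be picked up again via resume_session().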
    signal_handler = AgentSignalHandler()
    signal_handler.register_agent(agent)

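    # Train the agent and, on success, export the best actor network found
    # during training.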
    train_successful = agent.train()

    if train_successful:
        agent.export_best_network(path=actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest', epoch=0)