<#-- (c) https://github.com/MontiCore/monticore -->
2
3
<#setting number_format="computer">
<#assign config = configurations[0]>
4
<#assign rlAgentType=config.rlAlgorithm?switch("dqn", "DqnAgent", "ddpg", "DdpgAgent", "td3", "TwinDelayedDdpgAgent")>
Nicola Gatto's avatar
Nicola Gatto committed
5
from ${rlFrameworkModule}.agent import ${rlAgentType}
6
from ${rlFrameworkModule}.util import AgentSignalHandler
Nicola Gatto's avatar
Nicola Gatto committed
7
from ${rlFrameworkModule}.cnnarch_logger import ArchLogger
8
<#if config.rlAlgorithm=="ddpg" || config.rlAlgorithm=="td3">
Nicola Gatto's avatar
Nicola Gatto committed
9
10
from ${rlFrameworkModule}.CNNCreator_${criticInstanceName} import CNNCreator_${criticInstanceName}
</#if>
11
12
13
14
15
16
import ${rlFrameworkModule}.environment
import CNNCreator_${config.instanceName}

import os
import sys
import re
Nicola Gatto's avatar
Nicola Gatto committed
17
18
import time
import numpy as np
19
20
21
import mxnet as mx


Nicola Gatto's avatar
Nicola Gatto committed
22
def resume_session(sessions_dir):
    """Offer to resume the most recent interrupted training session.

    Entries of ``sessions_dir`` whose names contain a timestamp of the form
    ``dddd-dd-dd-dd-dd`` are visited in reverse lexicographic order. For the
    first entry that holds a ``.interrupted_session`` directory the user is
    asked on the console whether to resume; no further entries are offered,
    whatever the answer is.

    :param sessions_dir: directory that contains one sub-directory per session.
    :return: tuple ``(resume, resume_directory)``. ``resume`` is True only if
        the user answered ``'y'``; ``resume_directory`` is then the path of the
        ``.interrupted_session`` directory, otherwise None.
    """
    if not os.path.isdir(sessions_dir):
        return False, None

    session_pattern = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
    # sorted() rather than filter(...).sort(): on Python 3 filter() returns an
    # iterator without a sort() method.
    session_names = sorted(
        (name for name in os.listdir(sessions_dir)
         if session_pattern.search(name)),
        reverse=True)

    for name in session_names:
        interrupted_session_dir = os.path.join(
            sessions_dir, name, '.interrupted_session')
        if not os.path.isdir(interrupted_session_dir):
            continue

        # raw_input() only exists on Python 2; on Python 3 input() has the
        # same "return the typed line as a string" behaviour.
        try:
            ask = raw_input
        except NameError:
            ask = input

        answer = ask(
            'Interrupted session from {} found. Do you want to resume? (y/n) '.format(name))
        if answer == 'y':
            return True, interrupted_session_dir
        # Only the newest interrupted session is offered.
        return False, None

    return False, None


if __name__ == "__main__":
<#-- Template for the entry point of the generated trainer script. Lines that
     consist only of FreeMarker directives or comments are not emitted. -->
    # Name used for the output directory, the logger and the agent
<#if (config.agentName)??>
    agent_name = '${config.agentName}'
<#else>
    agent_name = '${config.instanceName}'
</#if>
    # Prepare output directory and logger
    all_output_dir = os.path.join('model', agent_name)
    output_directory = os.path.join(
        all_output_dir,
        time.strftime('%Y-%m-%d-%H-%M-%S',
                      time.localtime(time.time())))
    ArchLogger.set_output_directory(output_directory)
    ArchLogger.set_logger_name(agent_name)
    ArchLogger.set_output_level(ArchLogger.INFO)

    # Create the training environment (Gym or ROS wrapper, per configuration)
<#if config.environment.environment == "gym">
    env = ${rlFrameworkModule}.environment.GymEnvironment(<#if config.environment.name??>'${config.environment.name}'<#else>'CartPole-v0'</#if>)
<#else>
    env_params = {
        'ros_node_name': '${config.instanceName}TrainerNode',
<#if config.environment.state_topic??>
        'state_topic': '${config.environment.state_topic}',
</#if>
<#if config.environment.action_topic??>
        'action_topic': '${config.environment.action_topic}',
</#if>
<#if config.environment.reset_topic??>
        'reset_topic': '${config.environment.reset_topic}',
</#if>
<#if config.environment.terminal_state_topic??>
        'terminal_state_topic': '${config.environment.terminal_state_topic}',
</#if>
<#if config.environment.reward_topic??>
        'reward_topic': '${config.environment.reward_topic}',
</#if>
    }
    env = ${rlFrameworkModule}.environment.RosEnvironment(**env_params)
</#if>

    # MXNet context the networks are constructed on (CPU unless configured)
<#if (config.context)??>
    context = mx.${config.context}()
<#else>
    context = mx.cpu()
</#if>
    # Weight initializer handed to the Q-network / actor creator
<#if (config.configuration.initializer)??>
    initializer_params = {
<#list config.initializerParams?keys as param>
        '${param}': ${config.initializerParams[param]}<#sep>,
</#list>
    }
    initializer = mx.init.${config.initializerName?capitalize}(**initializer_params)
<#else>
    initializer = mx.init.Normal()
</#if>
<#if config.rlAlgorithm=="ddpg" || config.rlAlgorithm=="td3">
    # Weight initializer handed to the critic creator
<#if (config.configuration.criticInitializer)??>
    critic_initializer_params = {
<#list config.criticInitializerParams?keys as param>
        '${param}': ${config.criticInitializerParams[param]}<#sep>,
</#list>
    }
    critic_initializer = mx.init.${config.criticInitializerName?capitalize}(**critic_initializer_params)
<#else>
    critic_initializer = mx.init.Normal()
</#if>
</#if>
    # Build the network(s) required by the selected algorithm
<#if config.rlAlgorithm == "dqn">
    qnet_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
    qnet_creator.setWeightInitializer(initializer)
    qnet_creator.construct(context)
<#elseif config.rlAlgorithm=="ddpg" || config.rlAlgorithm=="td3">
    actor_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
    actor_creator.setWeightInitializer(initializer)
    actor_creator.construct(context)
    critic_creator = CNNCreator_${criticInstanceName}()
    critic_creator.setWeightInitializer(critic_initializer)
    critic_creator.construct(context)
</#if>

    # Keyword arguments for constructing a fresh agent; optional entries are
    # only emitted when the corresponding configuration value is present.
    agent_params = {
        'environment': env,
        'replay_memory_params': {
<#include "params/ReplayMemoryParams.ftl">
        },
        'strategy_params': {
<#include "params/StrategyParams.ftl">
        },
        'agent_name': agent_name,
        'verbose': True,
        'output_directory': output_directory,
        'state_dim': (<#list config.stateDim as d>${d},</#list>),
        'action_dim': (<#list config.actionDim as d>${d},</#list>),
<#if (config.context)??>
        'ctx': '${config.context}',
</#if>
<#if (config.discountFactor)??>
        'discount_factor': ${config.discountFactor},
</#if>
<#if (config.numEpisodes)??>
        'training_episodes': ${config.numEpisodes},
</#if>
<#if (config.trainingInterval)??>
        'train_interval': ${config.trainingInterval},
</#if>
<#if (config.startTrainingAt)??>
        'start_training': ${config.startTrainingAt},
</#if>
<#if (config.snapshotInterval)??>
        'snapshot_interval': ${config.snapshotInterval},
</#if>
<#if (config.numMaxSteps)??>
        'max_episode_step': ${config.numMaxSteps},
</#if>
<#if (config.evaluationSamples)??>
        'evaluation_samples': ${config.evaluationSamples},
</#if>
<#if (config.outputDirectory)??>
<#-- NOTE(review): this repeats the 'output_directory' key set above, so in the
     generated dict literal this later value wins while ArchLogger keeps the
     timestamped directory. The value is also emitted without quotes, unlike
     the other string-valued entries; confirm that config.outputDirectory
     already yields a valid Python expression. -->
        'output_directory': ${config.outputDirectory},
</#if>
<#if (config.targetScore)??>
        'target_score': ${config.targetScore},
</#if>
<#if (config.rlAlgorithm == "dqn")>
<#include "params/DqnAgentParams.ftl">
<#elseif config.rlAlgorithm == "ddpg">
<#include "params/DdpgAgentParams.ftl">
<#else>
<#include "params/Td3AgentParams.ftl">
</#if>
    }

    # Ask whether an interrupted session below all_output_dir should be resumed
    resume, resume_directory = resume_session(all_output_dir)

    if resume:
        # Continue logging into the directory that contains the session data
        output_directory, _ = os.path.split(resume_directory)
        ArchLogger.set_output_directory(output_directory)
        resume_agent_params = {
            'session_dir': resume_directory,
            'environment': env,
<#if config.rlAlgorithm == "dqn">
            'net': qnet_creator.networks[0],
<#else>
            'actor': actor_creator.networks[0],
            'critic': critic_creator.networks[0],
</#if>
        }
        agent = ${rlAgentType}.resume_from_session(**resume_agent_params)
    else:
        agent = ${rlAgentType}(**agent_params)

    # NOTE(review): presumably lets the agent react to process signals such as
    # an interrupt during training; confirm against AgentSignalHandler.
    signal_handler = AgentSignalHandler()
    signal_handler.register_agent(agent)

    train_successful = agent.train()

    if train_successful:
        # Export the best network under the creator's model directory/prefix
<#if (config.rlAlgorithm == "dqn")>
        agent.export_best_network(path=qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0)
<#else>
        agent.export_best_network(path=actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest', epoch=0)
</#if>