Trainer.ftl 6.91 KB
Newer Older
Bernhard Rumpe's avatar
BR-sy    
Bernhard Rumpe committed
1
<#-- (c) https://github.com/MontiCore/monticore -->
2
3
<#-- Render numbers in FreeMarker's "computer" format so numeric config
     values are emitted as plain literals (no locale-specific grouping)
     inside the generated Python code. -->
<#setting number_format="computer">
<#-- Only the first training configuration is used by this template. -->
<#assign config = configurations[0]>
4
<#-- Map the configured algorithm name to the agent class that is imported
     from ${rlFrameworkModule}.agent below. ?switch has no default case here,
     so an unsupported rlAlgorithm aborts template processing. -->
<#assign rlAgentType=config.rlAlgorithm?switch("dqn", "DqnAgent", "ddpg", "DdpgAgent", "td3", "TwinDelayedDdpgAgent")>
Nicola Gatto's avatar
Nicola Gatto committed
5
from ${rlFrameworkModule}.agent import ${rlAgentType}
6
from ${rlFrameworkModule}.util import AgentSignalHandler
Nicola Gatto's avatar
Nicola Gatto committed
7
from ${rlFrameworkModule}.cnnarch_logger import ArchLogger
8
<#if config.rlAlgorithm=="ddpg" || config.rlAlgorithm=="td3">
Nicola Gatto's avatar
Nicola Gatto committed
9
10
from ${rlFrameworkModule}.CNNCreator_${criticInstanceName} import CNNCreator_${criticInstanceName}
</#if>
11
12
13
14
15
16
import ${rlFrameworkModule}.environment
import CNNCreator_${config.instanceName}

import os
import sys
import re
Nicola Gatto's avatar
Nicola Gatto committed
17
18
import time
import numpy as np
19
20
21
import mxnet as mx


Nicola Gatto's avatar
Nicola Gatto committed
22
def resume_session(sessions_dir):
    """Ask the user whether to resume the newest interrupted session.

    Entries of ``sessions_dir`` whose names contain a
    ``YYYY-MM-DD-HH-MM`` timestamp are visited in reverse lexicographic
    order of their names. The first one that contains an
    ``.interrupted_session`` directory is offered to the user on stdin.
    Only that session is considered, whatever the answer.

    :param sessions_dir: directory that holds one sub-directory per session.
        A missing directory is treated as "no sessions".
    :return: tuple ``(resume, resume_directory)``. ``resume`` is True only
        if the user answered exactly ``'y'``, in which case
        ``resume_directory`` is the path of the session's
        ``.interrupted_session`` directory; otherwise ``(False, None)``.
    """
    # raw_input only exists on Python 2; input() is its Python 3 equivalent.
    try:
        read_answer = raw_input  # noqa: F821 (Python 2 builtin)
    except NameError:
        read_answer = input

    should_resume = False
    resume_directory = None
    if os.path.isdir(sessions_dir):
        regex = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
        dir_content = os.listdir(sessions_dir)
        # sorted() accepts both a Python 2 list and a Python 3 filter
        # iterator; calling .sort() on a Python 3 filter object would fail.
        session_files = sorted(filter(regex.search, dir_content), reverse=True)
        for d in session_files:
            interrupted_session_dir = os.path.join(sessions_dir, d, '.interrupted_session')
            if os.path.isdir(interrupted_session_dir):
                resume = read_answer('Interrupted session from {} found. Do you want to resume? (y/n) '.format(d))
                if resume == 'y':
                    should_resume = True
                    resume_directory = interrupted_session_dir
                # Only the newest interrupted session is offered.
                break
    return should_resume, resume_directory

Nicola Gatto's avatar
Nicola Gatto committed
40

41
if __name__ == "__main__":
<#if (config.agentName)??>
    agent_name = '${config.agentName}'
<#else>
    agent_name = '${config.instanceName}'
</#if>
    # Prepare output directory and logger: each run logs into its own
    # timestamped sub-directory below model/<agent_name>.
    all_output_dir = os.path.join('model', agent_name)
    output_directory = os.path.join(
        all_output_dir,
        time.strftime('%Y-%m-%d-%H-%M-%S',
                      time.localtime(time.time())))
    ArchLogger.set_output_directory(output_directory)
    ArchLogger.set_logger_name(agent_name)
    ArchLogger.set_output_level(ArchLogger.INFO)

    # Training environment: a Gym environment (CartPole-v0 unless a name is
    # configured) or a ROS environment built from the configured topics.
<#if config.environment.environment == "gym">
    env = ${rlFrameworkModule}.environment.GymEnvironment(<#if config.environment.name??>'${config.environment.name}'<#else>'CartPole-v0'</#if>)
<#else>
    env_params = {
        'ros_node_name': '${config.instanceName}TrainerNode',
<#if config.environment.state_topic??>
        'state_topic': '${config.environment.state_topic}',
</#if>
<#if config.environment.action_topic??>
        'action_topic': '${config.environment.action_topic}',
</#if>
<#if config.environment.reset_topic??>
        'reset_topic': '${config.environment.reset_topic}',
</#if>
<#if config.environment.terminal_state_topic??>
        'terminal_state_topic': '${config.environment.terminal_state_topic}',
</#if>
<#if config.environment.reward_topic??>
        'reward_topic': '${config.environment.reward_topic}',
</#if>
    }
    env = ${rlFrameworkModule}.environment.RosEnvironment(**env_params)
</#if>

    # MXNet device context; CPU unless a context is configured.
<#if (config.context)??>
    context = mx.${config.context}()
<#else>
    context = mx.cpu()
</#if>

    # Weight initializers for the main (Q- or actor) network and for the
    # critic network; mx.init.Normal() is used when none is configured.
<#if (config.configuration.initializer)??>
    initializer_params = {
<#list config.initializerParams?keys as param>
        '${param}': ${config.initializerParams[param]}<#sep>,
</#list>
    }
    initializer = mx.init.${config.initializerName?capitalize}(**initializer_params)
<#else>
    initializer = mx.init.Normal()
</#if>
<#if (config.configuration.criticInitializer)??>
    critic_initializer_params = {
<#list config.criticInitializerParams?keys as param>
        '${param}': ${config.criticInitializerParams[param]}<#sep>,
</#list>
    }
    critic_initializer = mx.init.${config.criticInitializerName?capitalize}(**critic_initializer_params)
<#else>
    critic_initializer = mx.init.Normal()
</#if>

    # Build the networks: a single Q-network for DQN, otherwise an actor
    # network plus a separately generated critic network.
<#if config.rlAlgorithm == "dqn">
    qnet_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
    qnet_creator.setWeightInitializer(initializer)
    qnet_creator.construct(context)
<#else>
    actor_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
    actor_creator.setWeightInitializer(initializer)
    actor_creator.construct(context)
    critic_creator = CNNCreator_${criticInstanceName}()
    critic_creator.setWeightInitializer(critic_initializer)
    critic_creator.construct(context)
</#if>

    agent_params = {
        'environment': env,
        'replay_memory_params': {
<#include "params/ReplayMemoryParams.ftl">
        },
        'strategy_params': {
<#include "params/StrategyParams.ftl">
        },
        'agent_name': agent_name,
        'verbose': True,
<#if (config.outputDirectory)??>
        # An explicitly configured output directory overrides the
        # timestamped default computed above.
        'output_directory': '${config.outputDirectory}',
<#else>
        'output_directory': output_directory,
</#if>
        'state_dim': (<#list config.stateDim as d>${d},</#list>),
        'action_dim': (<#list config.actionDim as d>${d},</#list>),
<#if (config.context)??>
        'ctx': '${config.context}',
</#if>
<#if (config.discountFactor)??>
        'discount_factor': ${config.discountFactor},
</#if>
<#if (config.numEpisodes)??>
        'training_episodes': ${config.numEpisodes},
</#if>
<#if (config.trainingInterval)??>
        'train_interval': ${config.trainingInterval},
</#if>
<#if (config.startTrainingAt)??>
        'start_training': ${config.startTrainingAt},
</#if>
<#if (config.snapshotInterval)??>
        'snapshot_interval': ${config.snapshotInterval},
</#if>
<#if (config.numMaxSteps)??>
        'max_episode_step': ${config.numMaxSteps},
</#if>
<#if (config.evaluationSamples)??>
        'evaluation_samples': ${config.evaluationSamples},
</#if>
<#if (config.targetScore)??>
        'target_score': ${config.targetScore},
</#if>
<#if (config.rlAlgorithm == "dqn")>
<#include "params/DqnAgentParams.ftl">
<#elseif config.rlAlgorithm == "ddpg">
<#include "params/DdpgAgentParams.ftl">
<#else>
<#include "params/Td3AgentParams.ftl">
</#if>
    }

    # Offer to resume an interrupted session of this agent; otherwise start
    # a fresh agent from agent_params.
    resume, resume_directory = resume_session(all_output_dir)

    if resume:
        # resume_directory is <session dir>/.interrupted_session, so keep
        # logging into the session directory that contains it.
        output_directory, _ = os.path.split(resume_directory)
        ArchLogger.set_output_directory(output_directory)
        resume_agent_params = {
            'session_dir': resume_directory,
            'environment': env,
<#if config.rlAlgorithm == "dqn">
            'net': qnet_creator.networks[0],
<#else>
            'actor': actor_creator.networks[0],
            'critic': critic_creator.networks[0],
</#if>
        }
        agent = ${rlAgentType}.resume_from_session(**resume_agent_params)
    else:
        agent = ${rlAgentType}(**agent_params)

    signal_handler = AgentSignalHandler()
    signal_handler.register_agent(agent)

    train_successful = agent.train()

    if train_successful:
        # Export via the agent to <model dir><model prefix>_0_newest
        # of the creator that built the exported network.
<#if (config.rlAlgorithm == "dqn")>
        agent.export_best_network(path=qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0)
<#else>
        agent.export_best_network(path=actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest', epoch=0)
</#if>