Commit 7a329bee authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Add test for TD3

parent 21ccbf74
......@@ -300,6 +300,35 @@ public class GenerationTest extends AbstractSymtabTest {
);
}
@Test
public void testTd3Config() {
    // Generator test for the TD3 (Twin Delayed DDPG) configuration:
    // run CNNTrain2Gluon on the valid_tests/td3 model and verify the
    // generated sources match the checked-in target code.
    Log.getFindings().clear();
    Path modelPath = Paths.get("src/test/resources/valid_tests/td3");
    CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
    // Mocked trained architecture stands in for a real actor network.
    TrainedArchitecture trainedArchitecture = TrainedArchitectureMockFactory.createTrainedArchitectureMock();
    trainGenerator.generate(modelPath, "TD3Config", trainedArchitecture);
    // Generation must not have produced any error findings.
    assertTrue(Log.getFindings().stream().noneMatch(Finding::isError));
    // Compare every generated file against the expected target code.
    checkFilesAreEqual(
            Paths.get("./target/generated-sources-cnnarch"),
            Paths.get("./src/test/resources/target_code/td3"),
            Arrays.asList(
                    "CNNTrainer_tD3Config.py",
                    "start_training.sh",
                    "reinforcement_learning/CNNCreator_CriticNetwork.py",
                    "reinforcement_learning/CNNNet_CriticNetwork.py",
                    "reinforcement_learning/__init__.py",
                    "reinforcement_learning/strategy.py",
                    "reinforcement_learning/agent.py",
                    "reinforcement_learning/environment.py",
                    "reinforcement_learning/replay_memory.py",
                    "reinforcement_learning/util.py",
                    "reinforcement_learning/cnnarch_logger.py"
            )
    );
}
@Test
public void testRosDdpgConfig() {
Log.getFindings().clear();
......
from reinforcement_learning.agent import TwinDelayedDdpgAgent
from reinforcement_learning.util import AgentSignalHandler
from reinforcement_learning.cnnarch_logger import ArchLogger
from reinforcement_learning.CNNCreator_CriticNetwork import CNNCreator_CriticNetwork
import reinforcement_learning.environment
import CNNCreator_tD3Config
import os
import sys
import re
import time
import numpy as np
import mxnet as mx
def resume_session(sessions_dir):
    """Look for an interrupted training session under *sessions_dir*.

    Scans for timestamped session directories (YYYY-MM-DD-HH-MM), newest
    first, and asks the user whether the most recent interrupted one
    should be resumed.

    Returns:
        tuple (resume_session, resume_directory): a bool, and the path to
        the '.interrupted_session' directory when the user accepted,
        otherwise None.
    """
    resume_session = False
    resume_directory = None
    if os.path.isdir(sessions_dir):
        regex = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
        dir_content = os.listdir(sessions_dir)
        # Use sorted() instead of list.sort() so this also works on
        # Python 3, where filter() returns an iterator, not a list.
        session_files = sorted(filter(regex.search, dir_content), reverse=True)
        # raw_input only exists on Python 2; fall back to input on Python 3.
        try:
            ask = raw_input
        except NameError:
            ask = input
        for d in session_files:
            interrupted_session_dir = os.path.join(
                sessions_dir, d, '.interrupted_session')
            if os.path.isdir(interrupted_session_dir):
                resume = ask('Interrupted session from {} found. Do you want to resume? (y/n) '.format(d))
                if resume == 'y':
                    resume_session = True
                    resume_directory = interrupted_session_dir
                # Only the most recent interrupted session is offered.
                break
    return resume_session, resume_directory
if __name__ == "__main__":
    # Generated TD3 training entry point: sets up logging, environment,
    # actor/critic networks, then trains (or resumes) the agent.
    agent_name = 'tD3Config'
    # Prepare output directory and logger
    all_output_dir = os.path.join('model', agent_name)
    output_directory = os.path.join(
        all_output_dir,
        time.strftime('%Y-%m-%d-%H-%M-%S',
                      time.localtime(time.time())))
    ArchLogger.set_output_directory(output_directory)
    ArchLogger.set_logger_name(agent_name)
    ArchLogger.set_output_level(ArchLogger.INFO)

    env = reinforcement_learning.environment.GymEnvironment('CartPole-v1')

    # All computation on CPU; actor is the generated network, the critic
    # comes from its own creator class.
    context = mx.cpu()
    actor_creator = CNNCreator_tD3Config.CNNCreator_tD3Config()
    actor_creator.construct(context)
    critic_creator = CNNCreator_CriticNetwork()
    critic_creator.construct(context)

    # Full agent configuration as generated from the TD3Config model.
    agent_params = {
        'environment': env,
        'replay_memory_params': {
            'method': 'online',
            'state_dtype': 'float32',
            'action_dtype': 'float32',
            'rewards_dtype': 'float32'
        },
        # Gaussian exploration noise with linearly decaying epsilon.
        'strategy_params': {
            'method':'gaussian',
            'epsilon': 1,
            'min_epsilon': 0.001,
            'epsilon_decay_method': 'linear',
            'epsilon_decay': 0.0001,
            'epsilon_decay_start': 50,
            'epsilon_decay_per_step': True,
            'noise_variance': 0.3,
            'action_low': -1,
            'action_high': 1,
        },
        'agent_name': agent_name,
        'verbose': True,
        'output_directory': output_directory,
        'state_dim': (8,),
        'action_dim': (3,),
        'actor': actor_creator.networks[0],
        'critic': critic_creator.networks[0],
        'soft_target_update_rate': 0.001,
        'actor_optimizer': 'adam',
        'actor_optimizer_params': {
            'learning_rate_minimum': 5.0E-5,
            'learning_rate_policy': 'step',
            'learning_rate': 1.0E-4,
            'learning_rate_decay': 0.9},
        'critic_optimizer': 'rmsprop',
        'critic_optimizer_params': {
            'learning_rate_minimum': 1.0E-4,
            'learning_rate_policy': 'step',
            'learning_rate': 0.001,
            'learning_rate_decay': 0.5},
        # TD3-specific: smoothed target-policy noise and delayed updates.
        'policy_noise': 0.1,
        'noise_clip': 0.8,
        'policy_delay': 4,
    }

    # Resume an interrupted session if the user confirms; otherwise start
    # a fresh agent from the parameters above.
    resume, resume_directory = resume_session(all_output_dir)
    if resume:
        output_directory, _ = os.path.split(resume_directory)
        ArchLogger.set_output_directory(output_directory)
        resume_agent_params = {
            'session_dir': resume_directory,
            'environment': env,
            'actor': actor_creator.networks[0],
            'critic': critic_creator.networks[0]
        }
        agent = TwinDelayedDdpgAgent.resume_from_session(**resume_agent_params)
    else:
        agent = TwinDelayedDdpgAgent(**agent_params)

    # Save agent state on interrupt signals so the session can be resumed.
    signal_handler = AgentSignalHandler()
    signal_handler.register_agent(agent)

    train_successful = agent.train()

    if train_successful:
        # Export the best actor network as the '_newest' snapshot.
        agent.export_best_network(path=actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest', epoch=0)
import mxnet as mx
import logging
import os
from CNNNet_CriticNetwork import Net_0
class CNNCreator_CriticNetwork:
    """Builds, checkpoints and restores the TD3 critic network (Net_0)."""

    # Checkpoint location and file-name prefix used by load/construct.
    _model_dir_ = "model/CriticNetwork/"
    _model_prefix_ = "model"

    def __init__(self):
        self.weight_initializer = mx.init.Normal()
        self.networks = {}

    def load(self, context):
        """Load the latest checkpoint of every network in self.networks.

        Stale '_newest' exports are deleted first, then the parameter file
        with the highest epoch number is loaded for each network.

        Returns:
            The smallest last-epoch across all networks (0 if any network
            had no checkpoint), or None when self.networks is empty.
        """
        earliestLastEpoch = None
        for i, network in self.networks.items():
            lastEpoch = 0
            param_file = None
            # Drop any previously exported '_newest' snapshot; it is
            # re-created after training. Missing files are fine.
            try:
                os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
            except OSError:
                pass
            try:
                os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-symbol.json")
            except OSError:
                pass
            if os.path.isdir(self._model_dir_):
                # Find the checkpoint with the highest epoch number.
                # ('fname' instead of 'file', which shadows a builtin.)
                for fname in os.listdir(self._model_dir_):
                    if ".params" in fname and self._model_prefix_ + "_" + str(i) in fname:
                        epochStr = fname.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
                        epoch = int(epochStr)
                        if epoch > lastEpoch:
                            lastEpoch = epoch
                            param_file = fname
            if param_file is None:
                earliestLastEpoch = 0
            else:
                logging.info("Loading checkpoint: " + param_file)
                network.load_parameters(self._model_dir_ + param_file)
                # 'is None' instead of '== None' (identity check, PEP 8).
                if earliestLastEpoch is None or lastEpoch < earliestLastEpoch:
                    earliestLastEpoch = lastEpoch
        return earliestLastEpoch

    def construct(self, context, data_mean=None, data_std=None):
        """Instantiate, initialize, hybridize and export network 0.

        The dummy forward pass with zero-filled state (1,8) and action
        (1,3) batches triggers shape inference before export.
        """
        self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
        self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
        self.networks[0].hybridize()
        self.networks[0](mx.nd.zeros((1, 8,), ctx=context), mx.nd.zeros((1, 3,), ctx=context))
        if not os.path.exists(self._model_dir_):
            os.makedirs(self._model_dir_)
        for i, network in self.networks.items():
            network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
import mxnet as mx
import numpy as np
from mxnet import gluon
class OneHot(gluon.HybridBlock):
    """One-hot encodes the argmax of each row of the input."""

    def __init__(self, size, **kwargs):
        super(OneHot, self).__init__(**kwargs)
        with self.name_scope():
            self.size = size

    def hybrid_forward(self, F, x):
        # Pick the most likely class per row, then one-hot encode it.
        winners = F.argmax(data=x, axis=1)
        return F.one_hot(indices=winners, depth=self.size)
class Softmax(gluon.HybridBlock):
    """Thin wrapper applying the softmax activation to its input."""

    def __init__(self, **kwargs):
        super(Softmax, self).__init__(**kwargs)

    def hybrid_forward(self, F, x):
        normalized = F.softmax(x)
        return normalized
class Split(gluon.HybridBlock):
    """Splits the input into num_outputs equal parts along an axis."""

    def __init__(self, num_outputs, axis=1, **kwargs):
        super(Split, self).__init__(**kwargs)
        with self.name_scope():
            self.num_outputs = num_outputs
            self.axis = axis

    def hybrid_forward(self, F, x):
        return F.split(data=x, num_outputs=self.num_outputs, axis=self.axis)
class Concatenate(gluon.HybridBlock):
    """Concatenates all of its inputs along the configured dimension."""

    def __init__(self, dim=1, **kwargs):
        super(Concatenate, self).__init__(**kwargs)
        with self.name_scope():
            self.dim = dim

    def hybrid_forward(self, F, *x):
        joined = F.concat(*x, dim=self.dim)
        return joined
class ZScoreNormalization(gluon.HybridBlock):
    """Normalizes the input with fixed, non-trainable mean/std constants."""

    def __init__(self, data_mean, data_std, **kwargs):
        super(ZScoreNormalization, self).__init__(**kwargs)
        with self.name_scope():
            self.data_mean = self.params.get('data_mean', shape=data_mean.shape,
                init=mx.init.Constant(data_mean.asnumpy().tolist()), differentiable=False)
            # Fix: use data_std.shape here (was data_mean.shape — a
            # copy-paste bug that breaks whenever the mean and std
            # statistics do not share the same shape).
            self.data_std = self.params.get('data_std', shape=data_std.shape,
                init=mx.init.Constant(data_std.asnumpy().tolist()), differentiable=False)

    def hybrid_forward(self, F, x, data_mean, data_std):
        # (x - mean) / std, broadcast over the batch dimension.
        x = F.broadcast_sub(x, data_mean)
        x = F.broadcast_div(x, data_std)
        return x
class Padding(gluon.HybridBlock):
    """Applies constant zero padding with a fixed pad_width spec."""

    def __init__(self, padding, **kwargs):
        super(Padding, self).__init__(**kwargs)
        with self.name_scope():
            self.pad_width = padding

    def hybrid_forward(self, F, x):
        padded = F.pad(data=x,
                       mode='constant',
                       pad_width=self.pad_width,
                       constant_value=0)
        return padded
class NoNormalization(gluon.HybridBlock):
    """Identity block used when no input normalization is configured."""

    def __init__(self, **kwargs):
        super(NoNormalization, self).__init__(**kwargs)

    def hybrid_forward(self, F, x):
        # Pass the input through unchanged.
        return x
class Net_0(gluon.HybridBlock):
    """Critic network: maps a (state, action) pair to a scalar Q-value.

    State and action are processed by separate dense branches whose
    outputs are summed, followed by a joint dense trunk.
    """
    def __init__(self, data_mean=None, data_std=None, **kwargs):
        super(Net_0, self).__init__(**kwargs)
        self.last_layers = {}
        with self.name_scope():
            # State branch: z-score normalize only if statistics are given.
            if data_mean:
                assert(data_std)
                self.input_normalization_state = ZScoreNormalization(data_mean=data_mean['state'],
                                                                     data_std=data_std['state'])
            else:
                self.input_normalization_state = NoNormalization()

            self.fc2_1_ = gluon.nn.Dense(units=300, use_bias=True)
            # fc2_1_, output shape: {[300,1,1]}
            self.relu2_1_ = gluon.nn.Activation(activation='relu')
            self.fc3_1_ = gluon.nn.Dense(units=600, use_bias=True)
            # fc3_1_, output shape: {[600,1,1]}

            # Action branch, normalized the same way as the state branch.
            if data_mean:
                assert(data_std)
                self.input_normalization_action = ZScoreNormalization(data_mean=data_mean['action'],
                                                                      data_std=data_std['action'])
            else:
                self.input_normalization_action = NoNormalization()

            self.fc2_2_ = gluon.nn.Dense(units=600, use_bias=True)
            # fc2_2_, output shape: {[600,1,1]}

            # Joint trunk applied after merging the two branches.
            self.fc4_ = gluon.nn.Dense(units=600, use_bias=True)
            # fc4_, output shape: {[600,1,1]}
            self.relu4_ = gluon.nn.Activation(activation='relu')
            self.fc5_ = gluon.nn.Dense(units=1, use_bias=True)
            # fc5_, output shape: {[1,1,1]}

    def hybrid_forward(self, F, state, action):
        outputs = []
        # State branch.
        state = self.input_normalization_state(state)
        fc2_1_ = self.fc2_1_(state)
        relu2_1_ = self.relu2_1_(fc2_1_)
        fc3_1_ = self.fc3_1_(relu2_1_)
        # Action branch.
        action = self.input_normalization_action(action)
        fc2_2_ = self.fc2_2_(action)
        # Merge state and action branches by element-wise addition.
        add4_ = fc3_1_ + fc2_2_
        fc4_ = self.fc4_(add4_)
        relu4_ = self.relu4_(fc4_)
        fc5_ = self.fc5_(relu4_)
        outputs.append(fc5_)
        return outputs[0]
import logging
import sys
import os
import util
class ArchLogger(object):
    """Process-wide logger facade for the generated RL training code.

    Configuration (name, directory, level, format) is set through the
    static setters *before* init_logger() is called; afterwards the
    configured logging.Logger is obtained via get_logger().
    """
    # The singleton logging.Logger, created lazily by init_logger().
    _logger = None

    # Class-private configuration, applied when init_logger() runs.
    __output_level = logging.INFO
    __logger_name = 'agent'
    __output_directory = '.'
    __append = True
    __logformat = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    __dateformat = '%d-%b-%y %H:%M:%S'

    # Convenience aliases so callers need not import logging themselves.
    INFO = logging.INFO
    DEBUG = logging.DEBUG

    @staticmethod
    def set_output_level(output_level):
        """Set the level used for both the stream and file handlers."""
        assert output_level is not None
        ArchLogger.__output_level = output_level

    @staticmethod
    def set_logger_name(logger_name):
        """Set the logger name (also the log file's base name)."""
        assert logger_name is not None
        ArchLogger.__logger_name = logger_name

    @staticmethod
    def set_output_directory(output_directory):
        """Set the directory the log file is written to."""
        assert output_directory is not None
        ArchLogger.__output_directory = output_directory

    @staticmethod
    def set_append(append):
        """Choose appending to vs. overwriting an existing log file."""
        assert append is not None
        ArchLogger.__append = append

    @staticmethod
    def set_log_format(logformat, dateformat):
        """Set the record format and date format strings."""
        assert logformat is not None
        assert dateformat is not None
        ArchLogger.__logformat = logformat
        ArchLogger.__dateformat = dateformat

    @staticmethod
    def init_logger(make_log_file=True):
        """Create the singleton logger; must only be called once."""
        assert ArchLogger._logger is None, 'Logger init already called'
        filemode = 'a' if ArchLogger.__append else 'w'
        formatter = logging.Formatter(
            fmt=ArchLogger.__logformat, datefmt=ArchLogger.__dateformat)

        logger = logging.getLogger(ArchLogger.__logger_name)
        # Keep records out of the root logger's handlers.
        logger.propagate = False

        # Only attach handlers if none exist yet, so re-obtaining an
        # already-configured logger does not duplicate output.
        if not logger.handlers:
            logger.setLevel(ArchLogger.__output_level)

            stream_handler = logging.StreamHandler(sys.stdout)
            stream_handler.setLevel(ArchLogger.__output_level)
            stream_handler.setFormatter(formatter)
            logger.addHandler(stream_handler)

            if make_log_file:
                # Ensure the output directory exists before opening the file.
                util.make_directory_if_not_exist(ArchLogger.__output_directory)
                log_file = os.path.join(
                    ArchLogger.__output_directory,
                    ArchLogger.__logger_name + '.log')
                file_handler = logging.FileHandler(log_file, mode=filemode)
                file_handler.setLevel(ArchLogger.__output_level)
                file_handler.setFormatter(formatter)
                logger.addHandler(file_handler)
        ArchLogger._logger = logger

    @staticmethod
    def get_logger():
        """Return the singleton logger, initializing it on first use."""
        if ArchLogger._logger is None:
            ArchLogger.init_logger()
        assert ArchLogger._logger is not None
        return ArchLogger._logger
if __name__ == "__main__":
    # Manual smoke test: configure the logger, emit records at several
    # levels, verify the log file was created, then clean up.
    print('=== Test logger ===')
    ArchLogger.set_logger_name('TestLogger')
    ArchLogger.set_output_directory('test_log')
    ArchLogger.init_logger()
    logger = ArchLogger.get_logger()
    logger.warning('This is a warning')
    logger.debug('This is a debug information, which you should not see')
    logger.info('This is a normal information')
    assert os.path.exists('test_log')\
        and os.path.isfile(os.path.join('test_log', 'TestLogger.log')),\
        'Test failed: No logfile exists'
    import shutil
    shutil.rmtree('test_log')
\ No newline at end of file
import abc
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Environment:
    """Abstract base class for the RL environments used by the agents.

    NOTE(review): '__metaclass__' only has an effect on Python 2; on
    Python 3 the abstract-method enforcement is inactive unless the
    class derives from abc.ABC — confirm the intended Python version.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self):
        pass

    @abc.abstractmethod
    def reset(self):
        # Reset the environment to an initial state.
        pass

    @abc.abstractmethod
    def step(self, action):
        # Apply *action* and advance the environment by one step.
        pass

    @abc.abstractmethod
    def close(self):
        # Release any resources held by the environment.
        pass
import gym
class GymEnvironment(Environment):
    """Environment implementation backed by an OpenAI Gym environment."""

    def __init__(self, env_name, **kwargs):
        """Create and seed the gym environment registered as *env_name*."""
        super(GymEnvironment, self).__init__(**kwargs)
        # Fixed seed for reproducible episodes.
        self.__seed = 42
        self.__env = gym.make(env_name)
        self.__env.seed(self.__seed)

    @property
    def state_dim(self):
        """Shape of the observation space."""
        return self.__env.observation_space.shape

    @property
    def number_of_actions(self):
        """Number of discrete actions (assumes a discrete action space)."""
        return self.__env.action_space.n

    @property
    def rewards_dtype(self):
        return 'float32'

    def reset(self):
        return self.__env.reset()

    def step(self, action):
        return self.__env.step(action)

    def close(self):
        self.__env.close()

    def action_space(self):
        # Bug fix: the attribute was looked up but never returned, so this
        # method always yielded None.
        return self.__env.action_space

    def is_in_action_space(self, action):
        return self.__env.action_space.contains(action)

    def sample_action(self):
        return self.__env.action_space.sample()

    def render(self):
        self.__env.render()
import numpy as np
class ReplayMemoryBuilder(object):
def __init__(self):
    """Record the replay-memory strategies this builder can construct."""
    self.__supported_methods = ['online', 'buffer', 'combined']
def build_by_params(
    self,
    state_dim,
    method='online',
    state_dtype='float32',
    action_dim=(1,),
    action_dtype='uint8',
    rewards_dtype='float32',
    memory_size=1000,
    sample_size=32
):
    """Construct a replay memory of the requested kind.

    'online' needs no capacity arguments; 'buffer' and 'combined'
    additionally require positive memory_size and sample_size.
    """
    assert state_dim is not None
    assert action_dim is not None
    assert method in self.__supported_methods

    # Online memory is unbuffered — no capacity checks needed.
    if method == 'online':
        return self.build_online_memory(
            state_dim=state_dim, state_dtype=state_dtype,
            action_dtype=action_dtype, action_dim=action_dim,
            rewards_dtype=rewards_dtype)

    # Buffered variants require a valid capacity and batch size.
    assert memory_size is not None and memory_size > 0
    assert sample_size is not None and sample_size > 0
    common_kwargs = dict(
        state_dim=state_dim, sample_size=sample_size,
        memory_size=memory_size, state_dtype=state_dtype,
        action_dim=action_dim, action_dtype=action_dtype,
        rewards_dtype=rewards_dtype)
    if method == 'buffer':
        return self.build_buffered_memory(**common_kwargs)
    return self.build_combined_memory(**common_kwargs)
def build_buffered_memory(