Aufgrund eines Versionsupdates wird GitLab am 01.04. zwischen 9:00 und 9:30 Uhr kurzzeitig nicht zur Verfügung stehen. / Due to a version upgrade, GitLab won't be accessible at 01.04. between 9:00 and 9:30 a.m.

Commit 3d680f2e authored by Sebastian Nickels's avatar Sebastian Nickels

Merge

parents b85b4bce d3c1bc00
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-emadl-generator</artifactId> <artifactId>embedded-montiarc-emadl-generator</artifactId>
<version>0.3.3-SNAPSHOT</version> <version>0.3.4-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
......
...@@ -9,6 +9,7 @@ import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNet; ...@@ -9,6 +9,7 @@ import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNet;
import de.monticore.lang.monticar.cnnarch.caffe2generator.CNNArch2Caffe2; import de.monticore.lang.monticar.cnnarch.caffe2generator.CNNArch2Caffe2;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet; import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet;
import de.monticore.lang.monticar.cnnarch.caffe2generator.CNNTrain2Caffe2; import de.monticore.lang.monticar.cnnarch.caffe2generator.CNNTrain2Caffe2;
import de.monticore.lang.monticar.emadl.generator.reinforcementlearning.RewardFunctionCppGenerator;
import java.util.Optional; import java.util.Optional;
......
...@@ -34,6 +34,8 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol; ...@@ -34,6 +34,8 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.SerialCompositeElementSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.SerialCompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNTrain2Gluon; import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNTrain2Gluon;
import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter; import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCoCoChecker;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.emadl._cocos.EMADLCocos; import de.monticore.lang.monticar.emadl._cocos.EMADLCocos;
import de.monticore.lang.monticar.generator.FileContent; import de.monticore.lang.monticar.generator.FileContent;
...@@ -115,6 +117,18 @@ public class EMADLGenerator { ...@@ -115,6 +117,18 @@ public class EMADLGenerator {
processedArchitecture = new HashMap<>(); processedArchitecture = new HashMap<>();
setModelsPath( modelPath ); setModelsPath( modelPath );
TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath()); TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath());
EMAComponentInstanceSymbol instance = resolveComponentInstanceSymbol(qualifiedName, symtab);
generateFiles(symtab, instance, symtab, pythonPath, forced);
if (doCompile) {
compile();
}
processedArchitecture = null;
}
private EMAComponentInstanceSymbol resolveComponentInstanceSymbol(String qualifiedName, TaggingResolver symtab) {
EMAComponentSymbol component = symtab.<EMAComponentSymbol>resolve(qualifiedName, EMAComponentSymbol.KIND).orElse(null); EMAComponentSymbol component = symtab.<EMAComponentSymbol>resolve(qualifiedName, EMAComponentSymbol.KIND).orElse(null);
List<String> splitName = Splitters.DOT.splitToList(qualifiedName); List<String> splitName = Splitters.DOT.splitToList(qualifiedName);
...@@ -126,15 +140,7 @@ public class EMADLGenerator { ...@@ -126,15 +140,7 @@ public class EMADLGenerator {
System.exit(1); System.exit(1);
} }
EMAComponentInstanceSymbol instance = component.getEnclosingScope().<EMAComponentInstanceSymbol>resolve(instanceName, EMAComponentInstanceSymbol.KIND).get(); return component.getEnclosingScope().<EMAComponentInstanceSymbol>resolve(instanceName, EMAComponentInstanceSymbol.KIND).get();
generateFiles(symtab, instance, symtab, pythonPath, forced);
if (doCompile) {
compile();
}
processedArchitecture = null;
} }
public void compile() throws IOException { public void compile() throws IOException {
...@@ -530,7 +536,32 @@ public class EMADLGenerator { ...@@ -530,7 +536,32 @@ public class EMADLGenerator {
final String fullConfigName = String.join(".", names); final String fullConfigName = String.join(".", names);
ArchitectureSymbol correspondingArchitecture = this.processedArchitecture.get(fullConfigName); ArchitectureSymbol correspondingArchitecture = this.processedArchitecture.get(fullConfigName);
assert correspondingArchitecture != null : "No architecture found for train " + fullConfigName + " configuration!"; assert correspondingArchitecture != null : "No architecture found for train " + fullConfigName + " configuration!";
configuration.setTrainedArchitecture(new ArchitectureAdapter(correspondingArchitecture)); configuration.setTrainedArchitecture(
new ArchitectureAdapter(correspondingArchitecture.getName(), correspondingArchitecture));
CNNTrainCocos.checkTrainedArchitectureCoCos(configuration);
// Resolve critic network if critic is present
if (configuration.getCriticName().isPresent()) {
String fullCriticName = configuration.getCriticName().get();
int indexOfFirstNameCharacter = fullCriticName.lastIndexOf('.') + 1;
fullCriticName = fullCriticName.substring(0, indexOfFirstNameCharacter)
+ fullCriticName.substring(indexOfFirstNameCharacter, indexOfFirstNameCharacter + 1).toUpperCase()
+ fullCriticName.substring(indexOfFirstNameCharacter + 1);
TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath());
EMAComponentInstanceSymbol instanceSymbol = resolveComponentInstanceSymbol(fullCriticName, symtab);
EMADLCocos.checkAll(instanceSymbol);
Optional<ArchitectureSymbol> critic = instanceSymbol.getSpannedScope().resolve("", ArchitectureSymbol.KIND);
if (!critic.isPresent()) {
Log.error("During the resolving of critic component: Critic component "
+ fullCriticName + " does not have a CNN implementation but is required to have one");
System.exit(-1);
}
critic.get().setComponentName(fullCriticName);
configuration.setCriticNetwork(new ArchitectureAdapter(fullCriticName, critic.get()));
CNNTrainCocos.checkCriticCocos(configuration);
}
cnnTrainGenerator.setInstanceName(componentInstance.getFullName().replaceAll("\\.", "_")); cnnTrainGenerator.setInstanceName(componentInstance.getFullName().replaceAll("\\.", "_"));
Map<String, String> fileContentMap = cnnTrainGenerator.generateStrings(configuration); Map<String, String> fileContentMap = cnnTrainGenerator.generateStrings(configuration);
......
package de.monticore.lang.monticar.emadl.generator; package de.monticore.lang.monticar.emadl.generator.reinforcementlearning;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol; import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator; import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
import de.monticore.lang.monticar.emadl.generator.EMADLAbstractSymtab;
import de.monticore.lang.monticar.generator.cpp.GeneratorEMAMOpt2CPP; import de.monticore.lang.monticar.generator.cpp.GeneratorEMAMOpt2CPP;
import de.monticore.lang.tagging._symboltable.TaggingResolver; import de.monticore.lang.tagging._symboltable.TaggingResolver;
import de.se_rwth.commons.logging.Log; import de.se_rwth.commons.logging.Log;
...@@ -9,30 +10,49 @@ import de.se_rwth.commons.logging.Log; ...@@ -9,30 +10,49 @@ import de.se_rwth.commons.logging.Log;
import java.io.IOException; import java.io.IOException;
import java.util.Optional; import java.util.Optional;
public class RewardFunctionCppGenerator implements RewardFunctionSourceGenerator { public class RewardFunctionCppGenerator implements RewardFunctionSourceGenerator{
public RewardFunctionCppGenerator() { public RewardFunctionCppGenerator() {
} }
@Override
public void generate(String modelPath, String rootModel, String targetPath) {
GeneratorEMAMOpt2CPP generator = new GeneratorEMAMOpt2CPP();
generator.useArmadilloBackend();
TaggingResolver taggingResolver = EMADLAbstractSymtab.createSymTabAndTaggingResolver(modelPath); @Override
public EMAComponentInstanceSymbol resolveSymbol(TaggingResolver taggingResolver, String rootModel) {
Optional<EMAComponentInstanceSymbol> instanceSymbol = taggingResolver Optional<EMAComponentInstanceSymbol> instanceSymbol = taggingResolver
.<EMAComponentInstanceSymbol>resolve(rootModel, EMAComponentInstanceSymbol.KIND); .<EMAComponentInstanceSymbol>resolve(rootModel, EMAComponentInstanceSymbol.KIND);
if (!instanceSymbol.isPresent()) { if (!instanceSymbol.isPresent()) {
Log.error("Generation of reward function is not possible: Cannot resolve component instance " Log.error("Generation of reward function is not possible: Cannot resolve component instance "
+ rootModel); + rootModel);
} }
return instanceSymbol.get();
}
@Override
public void generate(EMAComponentInstanceSymbol componentInstanceSymbol, TaggingResolver taggingResolver,
String targetPath) {
GeneratorEMAMOpt2CPP generator = new GeneratorEMAMOpt2CPP();
generator.useArmadilloBackend();
generator.setGenerationTargetPath(targetPath); generator.setGenerationTargetPath(targetPath);
try { try {
generator.generate(instanceSymbol.get(), taggingResolver); generator.generate(componentInstanceSymbol, taggingResolver);
} catch (IOException e) { } catch (IOException e) {
Log.error("Generation of reward function is not possible: " + e.getMessage()); Log.error("Generation of reward function is not possible: " + e.getMessage());
} }
}
@Override
public void generate(String modelPath, String rootModel, String targetPath) {
TaggingResolver taggingResolver = createTaggingResolver(modelPath);
EMAComponentInstanceSymbol instanceSymbol = resolveSymbol(taggingResolver, rootModel);
generate(instanceSymbol, taggingResolver, targetPath);
}
@Override
public TaggingResolver createTaggingResolver(final String modelPath) {
return EMADLAbstractSymtab.createSymTabAndTaggingResolver(modelPath);
} }
} }
...@@ -275,8 +275,8 @@ public class GenerationTest extends AbstractSymtabTest { ...@@ -275,8 +275,8 @@ public class GenerationTest extends AbstractSymtabTest {
"HelperA.h", "HelperA.h",
"start_training.sh", "start_training.sh",
"reinforcement_learning/__init__.py", "reinforcement_learning/__init__.py",
"reinforcement_learning/CNNCreator_MountaincarCritic.py", "reinforcement_learning/CNNCreator_mountaincar_agent_mountaincarCritic.py",
"reinforcement_learning/CNNNet_MountaincarCritic.py", "reinforcement_learning/CNNNet_mountaincar_agent_mountaincarCritic.py",
"reinforcement_learning/strategy.py", "reinforcement_learning/strategy.py",
"reinforcement_learning/agent.py", "reinforcement_learning/agent.py",
"reinforcement_learning/environment.py", "reinforcement_learning/environment.py",
......
...@@ -17,7 +17,7 @@ configuration CartPoleDQN { ...@@ -17,7 +17,7 @@ configuration CartPoleDQN {
use_double_dqn : false use_double_dqn : false
loss : euclidean loss : huber
replay_memory : buffer{ replay_memory : buffer{
memory_size : 10000 memory_size : 10000
......
implementation Critic(state, action) {
(state ->
FullyConnected(units=400) ->
Relu() ->
FullyConnected(units=300)
|
action ->
FullyConnected(units=300)
) ->
Add() ->
Relu();
}
\ No newline at end of file
package mountaincar.agent;
component MountaincarCritic {
ports
in Q^{2} state,
in Q(-1:1)^{1} action,
out Q(-oo:oo)^{1} qvalues;
implementation CNN {
(
state ->
FullyConnected(units=400) ->
Relu() ->
FullyConnected(units=300)
|
action ->
FullyConnected(units=300)
) ->
Add() ->
Relu() ->
FullyConnected(units=1) ->
qvalues;
}
}
\ No newline at end of file
...@@ -23,7 +23,7 @@ configuration TorcsDQN { ...@@ -23,7 +23,7 @@ configuration TorcsDQN {
use_double_dqn : true use_double_dqn : true
loss : euclidean loss : huber
replay_memory : buffer{ replay_memory : buffer{
memory_size : 1000000 memory_size : 1000000
......
...@@ -3,8 +3,9 @@ import h5py ...@@ -3,8 +3,9 @@ import h5py
import mxnet as mx import mxnet as mx
import logging import logging
import sys import sys
from mxnet import nd
class cartpole_master_dqnDataLoader: class CNNDataLoader_cartpole_master_dqn:
_input_names_ = ['state'] _input_names_ = ['state']
_output_names_ = ['qvalues_label'] _output_names_ = ['qvalues_label']
...@@ -14,21 +15,38 @@ class cartpole_master_dqnDataLoader: ...@@ -14,21 +15,38 @@ class cartpole_master_dqnDataLoader:
def load_data(self, batch_size): def load_data(self, batch_size):
train_h5, test_h5 = self.load_h5_files() train_h5, test_h5 = self.load_h5_files()
data_mean = train_h5[self._input_names_[0]][:].mean(axis=0) train_data = {}
data_std = train_h5[self._input_names_[0]][:].std(axis=0) + 1e-5 data_mean = {}
data_std = {}
for input_name in self._input_names_:
train_data[input_name] = train_h5[input_name]
data_mean[input_name] = nd.array(train_h5[input_name][:].mean(axis=0))
data_std[input_name] = nd.array(train_h5[input_name][:].std(axis=0) + 1e-5)
train_label = {}
for output_name in self._output_names_:
train_label[output_name] = train_h5[output_name]
train_iter = mx.io.NDArrayIter(data=train_data,
label=train_label,
batch_size=batch_size)
train_iter = mx.io.NDArrayIter(train_h5[self._input_names_[0]],
train_h5[self._output_names_[0]],
batch_size=batch_size,
data_name=self._input_names_[0],
label_name=self._output_names_[0])
test_iter = None test_iter = None
if test_h5 != None: if test_h5 != None:
test_iter = mx.io.NDArrayIter(test_h5[self._input_names_[0]], test_data = {}
test_h5[self._output_names_[0]], for input_name in self._input_names_:
batch_size=batch_size, test_data[input_name] = test_h5[input_name]
data_name=self._input_names_[0],
label_name=self._output_names_[0]) test_label = {}
for output_name in self._output_names_:
test_label[output_name] = test_h5[output_name]
test_iter = mx.io.NDArrayIter(data=test_data,
label=test_label,
batch_size=batch_size)
return train_iter, test_iter, data_mean, data_std return train_iter, test_iter, data_mean, data_std
def load_h5_files(self): def load_h5_files(self):
...@@ -36,21 +54,39 @@ class cartpole_master_dqnDataLoader: ...@@ -36,21 +54,39 @@ class cartpole_master_dqnDataLoader:
test_h5 = None test_h5 = None
train_path = self._data_dir + "train.h5" train_path = self._data_dir + "train.h5"
test_path = self._data_dir + "test.h5" test_path = self._data_dir + "test.h5"
if os.path.isfile(train_path): if os.path.isfile(train_path):
train_h5 = h5py.File(train_path, 'r') train_h5 = h5py.File(train_path, 'r')
if not (self._input_names_[0] in train_h5 and self._output_names_[0] in train_h5):
logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the datasets: " for input_name in self._input_names_:
+ "'" + self._input_names_[0] + "', '" + self._output_names_[0] + "'") if not input_name in train_h5:
sys.exit(1) logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the dataset "
test_iter = None + "'" + input_name + "'")
sys.exit(1)
for output_name in self._output_names_:
if not output_name in train_h5:
logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the dataset "
+ "'" + output_name + "'")
sys.exit(1)
if os.path.isfile(test_path): if os.path.isfile(test_path):
test_h5 = h5py.File(test_path, 'r') test_h5 = h5py.File(test_path, 'r')
if not (self._input_names_[0] in test_h5 and self._output_names_[0] in test_h5):
logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the datasets: " for input_name in self._input_names_:
+ "'" + self._input_names_[0] + "', '" + self._output_names_[0] + "'") if not input_name in test_h5:
sys.exit(1) logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the dataset "
+ "'" + input_name + "'")
sys.exit(1)
for output_name in self._output_names_:
if not output_name in test_h5:
logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the dataset "
+ "'" + output_name + "'")
sys.exit(1)
else: else:
logging.warning("Couldn't load test set. File '" + os.path.abspath(test_path) + "' does not exist.") logging.warning("Couldn't load test set. File '" + os.path.abspath(test_path) + "' does not exist.")
return train_h5, test_h5 return train_h5, test_h5
else: else:
logging.error("Data loading failure. File '" + os.path.abspath(train_path) + "' does not exist.") logging.error("Data loading failure. File '" + os.path.abspath(train_path) + "' does not exist.")
......
...@@ -101,7 +101,6 @@ class Net_0(gluon.HybridBlock): ...@@ -101,7 +101,6 @@ class Net_0(gluon.HybridBlock):
self.fc3_ = gluon.nn.Dense(units=2, use_bias=True) self.fc3_ = gluon.nn.Dense(units=2, use_bias=True)
# fc3_, output shape: {[2,1,1]} # fc3_, output shape: {[2,1,1]}
self.last_layers['qvalues'] = 'linear'
def hybrid_forward(self, F, state): def hybrid_forward(self, F, state):
......
...@@ -56,7 +56,7 @@ if __name__ == "__main__": ...@@ -56,7 +56,7 @@ if __name__ == "__main__":
'memory_size': 10000, 'memory_size': 10000,
'sample_size': 32, 'sample_size': 32,
'state_dtype': 'float32', 'state_dtype': 'float32',
'action_dtype': 'float32', 'action_dtype': 'uint8',
'rewards_dtype': 'float32' 'rewards_dtype': 'float32'
}, },
'strategy_params': { 'strategy_params': {
...@@ -78,10 +78,10 @@ if __name__ == "__main__": ...@@ -78,10 +78,10 @@ if __name__ == "__main__":
'snapshot_interval': 20, 'snapshot_interval': 20,
'max_episode_step': 250, 'max_episode_step': 250,
'target_score': 185.5, 'target_score': 185.5,
'qnet':qnet_creator.net, 'qnet':qnet_creator.networks[0],
'use_fix_target': True, 'use_fix_target': True,
'target_update_interval': 200, 'target_update_interval': 200,
'loss_function': 'euclidean', 'loss_function': 'huber',
'optimizer': 'rmsprop', 'optimizer': 'rmsprop',
'optimizer_params': { 'optimizer_params': {
'learning_rate': 0.001 }, 'learning_rate': 0.001 },
...@@ -108,4 +108,4 @@ if __name__ == "__main__": ...@@ -108,4 +108,4 @@ if __name__ == "__main__":
train_successful = agent.train() train_successful = agent.train()
if train_successful: if train_successful:
agent.save_best_network(qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0) agent.export_best_network(path=qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0)
...@@ -114,6 +114,8 @@ class Agent(object): ...@@ -114,6 +114,8 @@ class Agent(object):
agent_session_file = os.path.join(session_dir, 'agent.p') agent_session_file = os.path.join(session_dir, 'agent.p')
logger = self._logger logger = self._logger
self._training_stats.save_stats(self._output_directory, episode=self._current_episode)
self._make_pickle_ready(session_dir) self._make_pickle_ready(session_dir)
with open(agent_session_file, 'wb') as f: with open(agent_session_file, 'wb') as f:
...@@ -122,10 +124,10 @@ class Agent(object): ...@@ -122,10 +124,10 @@ class Agent(object):
def _make_pickle_ready(self, session_dir): def _make_pickle_ready(self, session_dir):
del self._training_stats.logger del self._training_stats.logger
self._logger = None
self._environment.close() self._environment.close()
self._environment = None self._environment = None
self._save_net(self._best_net, 'best_net', session_dir) self._export_net(self._best_net, 'best_net', filedir=session_dir)
self._logger = None
self._best_net = None self._best_net = None
def _make_config_dict(self): def _make_config_dict(self):
...@@ -177,6 +179,9 @@ class Agent(object): ...@@ -177,6 +179,9 @@ class Agent(object):
return states, actions, rewards, next_states, terminals return states, actions, rewards, next_states, terminals
def evaluate(self, target=None, sample_games=100, verbose=True): def evaluate(self, target=None, sample_games=100, verbose=True):
if sample_games <= 0:
return 0
target = self._target_score if target is None else target target = self._target_score if target is None else target
if target: if target:
target_achieved = 0 target_achieved = 0
...@@ -253,25 +258,22 @@ class Agent(object): ...@@ -253,25 +258,22 @@ class Agent(object):
return self._target_score is not None\ return self._target_score is not None\
and avg_reward > self._target_score and avg_reward > self._target_score
def _save_parameters(self, net, episode=None, filename='dqn-agent-params'): def _export_net(self, net, filename, filedir=None, episode=None):
assert self._output_directory assert self._output_directory
assert isinstance(net, gluon.HybridBlock) assert isinstance(net, gluon.HybridBlock)
make_directory_if_not_exist(self._output_directory) make_directory_if_not_exist(self._output_directory)
filedir = self._output_directory if filedir is None else filedir
filename = os.path.join(filedir, filename)
if(episode is not None): if episode is not None:
self._logger.info(
'Saving model parameters after episode %d' % episode)
filename = filename + '-ep{}'.format(episode) filename = filename + '-ep{}'.format(episode)
else:
self._logger.info('Saving model parameters')
self._save_net(net, filename)
def _save_net(self, net, filename, filedir=None): net.export(filename, epoch=0)
filedir = self._output_directory if filedir is None else filedir net.save_parameters(filename + '.params')
filename = os.path.join(filedir, filename + '.params')
net.save_parameters(filename)
def save_best_network(self, path, epoch=0): def export_best_network(self, path=None, epoch=0):
path = os.path.join(self._output_directory, 'best_network')\
if path is None else path
self._logger.info( self._logger.info(
'Saving best network with average reward of {}'.format( 'Saving best network with average reward of {}'.format(
self._best_avg_score)) self._best_avg_score))
...@@ -367,13 +369,17 @@ class DdpgAgent(Agent): ...@@ -367,13 +369,17 @@ class DdpgAgent(Agent):
def _make_pickle_ready(self, session_dir): def _make_pickle_ready(self, session_dir):
super(DdpgAgent, self)._make_pickle_ready(session_dir) super(DdpgAgent, self)._make_pickle_ready(session_dir)
self._save_net(self._actor, 'actor', session_dir) self._export_net(self._actor, 'current_actor')
self._export_net(self._actor, 'actor', filedir=session_dir)
self._actor = None self._actor = None
self._save_net(self._critic, 'critic', session_dir) self._export_net(self._critic, 'critic', filedir=session_dir)
self._critic = None self._critic = None
self._save_net(self._actor_target, 'actor_target', session_dir) self._export_net(
self._actor_target, 'actor_target', filedir=session_dir)
self._actor_target = None self._actor_target = None
self._save_net(self._critic_target, 'critic_target', session_dir) self._export_net(
self._critic_target, 'critic_target', filedir=session_dir)
self._critic_target = None self._critic_target = None
@classmethod @classmethod
...@@ -441,7 +447,8 @@ class DdpgAgent(Agent): ...@@ -441,7 +447,8 @@ class DdpgAgent(Agent):
return action[0].asnumpy() return action[0].asnumpy()
def save_parameters(self, episode): def save_parameters(self, episode):
self._save_parameters(self._actor, episode=episode) self._export_net(
self._actor, self._agent_name + '_actor', episode=episode)
def train(self, episodes=None): def train(self, episodes=None):
self.save_config_file() self.save_config_file()
...@@ -457,9 +464,9 @@ class DdpgAgent(Agent): ...@@ -457,9 +464,9 @@ class DdpgAgent(Agent):
else: else:
self._training_stats = DdpgTrainingStats(episodes) self._training_stats = DdpgTrainingStats(episodes)
# Initialize target Q' and mu' # Initialize target Q' and mu'
self._actor_target = self._copy_actor() self._actor_target = self._copy_actor()
self._critic_target = self._copy_critic() self._critic_target = self._copy_critic()
# Initialize l2 loss for critic network # Initialize l2 loss for critic network
l2_loss = gluon.loss.L2Loss() l2_loss = gluon.loss.L2Loss()
...@@ -496,6 +503,7 @@ class DdpgAgent(Agent): ...@@ -496,6 +503,7 @@ class DdpgAgent(Agent):
# actor and exploration noise N according to strategy # actor and exploration noise N according to strategy
action = self._strategy.select_action(