Commit 78559a89 authored by Nicola Gatto's avatar Nicola Gatto Committed by Evgeny Kusmenko

Implement ddpg

parent 736a2256
......@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId>
<version>0.2.0-SNAPSHOT</version>
<version>0.2.1-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -16,8 +16,8 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.0-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.0-SNAPSHOT</CNNTrain.version>
<CNNArch2MXNet.version>0.2.14-SNAPSHOT</CNNArch2MXNet.version>
<CNNTrain.version>0.3.2-SNAPSHOT</CNNTrain.version>
<CNNArch2MXNet.version>0.2.15-SNAPSHOT</CNNArch2MXNet.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.1</EMADL2PythonWrapper.version>
......
......@@ -20,11 +20,13 @@
*/
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.IOSymbol;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNet;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.Target;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration;
import de.se_rwth.commons.logging.Log;
import java.util.HashMap;
import java.util.Map;
......@@ -34,35 +36,81 @@ public class CNNArch2Gluon extends CNNArch2MxNet {
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
@Override
public Map<String, String> generateStrings(ArchitectureSymbol architecture){
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
Map<String, String> fileContentMap = compileFileContentMap(architecture);
checkValidGeneration(architecture);
return fileContentMap;
}
public Map<String, String> generateStringsAllowMultipleIO(ArchitectureSymbol architecture, Boolean pythonFilesOnly) {
Map<String, String> fileContentMap;
if (pythonFilesOnly) {
fileContentMap = compilePythonFilesOnlyContentMap(architecture);
} else {
fileContentMap = compileFileContentMap(architecture);
}
checkValidOutputTypes(architecture);
return fileContentMap;
}
private void checkValidOutputTypes(ArchitectureSymbol architecture) {
if (((IOSymbol)architecture.getOutputs().get(0)).getDefinition().getType().getWidth() != 1
|| ((IOSymbol)architecture.getOutputs().get(0)).getDefinition().getType().getHeight() != 1) {
Log.error("This cnn architecture has a multi-dimensional output, which is currently not supported by" +
" the code generator.", architecture.getSourcePosition());
}
}
private Map<String, String> compilePythonFiles(CNNArch2GluonTemplateController controller, ArchitectureSymbol architecture) {
Map<String, String> fileContentMap = new HashMap<>();
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration);
Map.Entry<String, String> temp;
temp = archTc.process("CNNPredictor", Target.CPP);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("CNNNet", Target.PYTHON);
temp = controller.process("CNNNet", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
if (architecture.getDataPath() != null) {
temp = archTc.process("CNNDataLoader", Target.PYTHON);
temp = controller.process("CNNDataLoader", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
}
temp = archTc.process("CNNCreator", Target.PYTHON);
temp = controller.process("CNNCreator", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
return fileContentMap;
}
private Map<String, String> compileCppFiles(CNNArch2GluonTemplateController controller) {
Map<String, String> fileContentMap = new HashMap<>();
Map.Entry<String, String> temp;
temp = controller.process("CNNPredictor", Target.CPP);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("execute", Target.CPP);
temp = controller.process("execute", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
temp = archTc.process("CNNBufferFile", Target.CPP);
temp = controller.process("CNNBufferFile", Target.CPP);
fileContentMap.put("CNNBufferFile.h", temp.getValue());
checkValidGeneration(architecture);
return fileContentMap;
}
private Map<String, String> compileFileContentMap(ArchitectureSymbol architecture) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
Map<String, String> fileContentMap = new HashMap<>();
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration);
fileContentMap.putAll(compilePythonFiles(archTc, architecture));
fileContentMap.putAll(compileCppFiles(archTc));
return fileContentMap;
}
}
private Map<String, String> compilePythonFilesOnlyContentMap(ArchitectureSymbol architecture) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration);
return compilePythonFiles(archTc, architecture);
}
}
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import com.google.common.collect.Maps;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerationPair;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerator;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
......@@ -10,7 +12,9 @@ import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm;
import de.monticore.lang.monticar.cnntrain._symboltable.RewardFunctionSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapperStandaloneApi;
......@@ -25,6 +29,7 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.stream.Collectors;
public class CNNTrain2Gluon extends CNNTrain2MxNet {
private static final String REINFORCEMENT_LEARNING_FRAMEWORK_MODULE = "reinforcement_learning";
......@@ -62,6 +67,15 @@ public class CNNTrain2Gluon extends CNNTrain2MxNet {
public void generate(Path modelsDirPath, String rootModelName) {
ConfigurationSymbol configuration = this.getConfigurationSymbol(modelsDirPath, rootModelName);
if (configuration.getLearningMethod().equals(LearningMethod.REINFORCEMENT)) {
throw new IllegalStateException("Cannot call generate of reinforcement configuration without specifying " +
"the trained architecture");
}
generateFilesFromConfigurationSymbol(configuration);
}
private void generateFilesFromConfigurationSymbol(ConfigurationSymbol configuration) {
Map<String, String> fileContents = this.generateStrings(configuration);
GeneratorCPP genCPP = new GeneratorCPP();
genCPP.setGenerationTargetPath(this.getGenerationTargetPath());
......@@ -78,6 +92,13 @@ public class CNNTrain2Gluon extends CNNTrain2MxNet {
}
}
public void generate(Path modelsDirPath, String rootModelName, TrainedArchitecture trainedArchitecture) {
ConfigurationSymbol configurationSymbol = this.getConfigurationSymbol(modelsDirPath, rootModelName);
configurationSymbol.setTrainedArchitecture(trainedArchitecture);
this.setRootProjectModelsDir(modelsDirPath.toString());
generateFilesFromConfigurationSymbol(configurationSymbol);
}
@Override
public Map<String, String> generateStrings(ConfigurationSymbol configuration) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
......@@ -98,8 +119,30 @@ public class CNNTrain2Gluon extends CNNTrain2MxNet {
fileContentMap.put("supervised_trainer.py", cnnSupervisedTrainerContent);
} else if (configData.isReinforcementLearning()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
if (rlAlgorithm.equals(RLAlgorithm.DDPG)) {
CriticNetworkGenerator criticNetworkGenerator = new CriticNetworkGenerator();
criticNetworkGenerator.setGenerationTargetPath(
Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString());
if (getRootProjectModelsDir().isPresent()) {
criticNetworkGenerator.setRootModelsDir(getRootProjectModelsDir().get());
} else {
Log.error("No root model dir set");
}
CriticNetworkGenerationPair criticNetworkResult
= criticNetworkGenerator.generateCriticNetworkContent(templateConfiguration, configuration);
fileContentMap.putAll(criticNetworkResult.getFileContent().entrySet().stream().collect(Collectors.toMap(
k -> REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
);
ftlContext.put("criticInstanceName", criticNetworkResult.getCriticNetworkName());
}
ftlContext.put("trainerName", trainerName);
Map<String, String> rlFrameworkContentMap = constructReinforcementLearningFramework(templateConfiguration, ftlContext);
Map<String, String> rlFrameworkContentMap = constructReinforcementLearningFramework(templateConfiguration, ftlContext, rlAlgorithm);
fileContentMap.putAll(rlFrameworkContentMap);
final String reinforcementTrainerContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/Trainer.ftl");
......@@ -161,23 +204,34 @@ public class CNNTrain2Gluon extends CNNTrain2MxNet {
}
private Map<String, String> constructReinforcementLearningFramework(
final TemplateConfiguration templateConfiguration, final Map<String, Object> ftlContext) {
final TemplateConfiguration templateConfiguration,
final Map<String, Object> ftlContext,
RLAlgorithm rlAlgorithm) {
Map<String, String> fileContentMap = Maps.newHashMap();
ftlContext.put("rlFrameworkModule", REINFORCEMENT_LEARNING_FRAMEWORK_MODULE);
final String reinforcementAgentContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/agent/Agent.ftl");
final String loggerContent = templateConfiguration.processTemplate(ftlContext,
"reinforcement/util/Logger.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/cnnarch_logger.py", loggerContent);
final String reinforcementAgentContent = templateConfiguration.processTemplate(ftlContext,
"reinforcement/agent/Agent.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/agent.py", reinforcementAgentContent);
final String reinforcementPolicyContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/agent/ActionPolicy.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/action_policy.py", reinforcementPolicyContent);
final String reinforcementStrategyContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/agent/Strategy.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/strategy.py", reinforcementStrategyContent);
final String replayMemoryContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/agent/ReplayMemory.ftl");
final String replayMemoryContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/agent/ReplayMemory.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/replay_memory.py", replayMemoryContent);
final String environmentContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/environment/Environment.ftl");
final String environmentContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/environment/Environment.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/environment.py", environmentContent);
final String utilContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/util/Util.ftl");
final String utilContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/util/Util.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/util.py", utilContent);
final String initContent = "";
......
package de.monticore.lang.monticar.cnnarch.gluongenerator.annotations;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.IOSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.lang.monticar.ranges._ast.ASTRange;
import de.monticore.symboltable.CommonSymbol;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static com.google.common.base.Preconditions.checkNotNull;
public class ArchitectureAdapter implements TrainedArchitecture {
private ArchitectureSymbol architectureSymbol;
public ArchitectureAdapter(final ArchitectureSymbol architectureSymbol) {
checkNotNull(architectureSymbol);
this.architectureSymbol = architectureSymbol;
}
@Override
public List<String> getInputs() {
return getIOInputSymbols().stream()
.map(CommonSymbol::getName)
.collect(Collectors.toList());
}
@Override
public List<String> getOutputs() {
return getIOOutputSymbols().stream()
.map(CommonSymbol::getName)
.collect(Collectors.toList());
}
@Override
public Map<String, List<Integer>> getDimensions() {
return getAllIOSymbols().stream().collect(Collectors.toMap(CommonSymbol::getName,
s-> s.getDefinition().getType().getDimensions()));
}
@Override
public Map<String, Range> getRanges() {
return getAllIOSymbols().stream().collect(Collectors.toMap(CommonSymbol::getName,
s -> astRangeToTrainRange(s.getDefinition().getType().getDomain().getRangeOpt().orElse(null))));
}
@Override
public Map<String, String> getTypes() {
return getAllIOSymbols().stream().collect(Collectors.toMap(CommonSymbol::getName,
s -> s.getDefinition().getType().getDomain().getName()));
}
private Range astRangeToTrainRange(final ASTRange range) {
if (range == null || (range.hasNoLowerLimit() && range.hasNoUpperLimit())) {
return Range.withInfinityLimits();
} else if (range.hasNoUpperLimit() && !range.hasNoLowerLimit()) {
double lowerLimit = range.getStartValue().doubleValue();
return Range.withUpperInfinityLimit(lowerLimit);
} else if (!range.hasNoUpperLimit() && range.hasNoLowerLimit()) {
double upperLimit = range.getEndValue().doubleValue();
return Range.withLowerInfinityLimit(upperLimit);
} else {
double lowerLimit = range.getStartValue().doubleValue();
double upperLimit = range.getEndValue().doubleValue();
return Range.withLimits(lowerLimit, upperLimit);
}
}
private List<IOSymbol> getIOOutputSymbols() {
return architectureSymbol.getOutputs();
}
private List<IOSymbol> getIOInputSymbols() {
return architectureSymbol.getInputs();
}
private List<IOSymbol> getAllIOSymbols() {
List<IOSymbol> ioSymbols = new ArrayList<>();
ioSymbols.addAll(getIOOutputSymbols());
ioSymbols.addAll(getIOInputSymbols());
return ioSymbols;
}
}
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic;
public class CriticNetworkGenerationException extends RuntimeException {
public CriticNetworkGenerationException(String s) {
super("Generation of critic network failed: " + s);
}
}
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic;
import java.util.Map;
public class CriticNetworkGenerationPair {
private String criticNetworkName;
private Map<String, String> fileContent;
public CriticNetworkGenerationPair(String criticNetworkName, Map<String, String> fileContent) {
this.criticNetworkName = criticNetworkName;
this.fileContent = fileContent;
}
public String getCriticNetworkName() {
return criticNetworkName;
}
public Map<String, String> getFileContent() {
return fileContent;
}
}
......@@ -6,7 +6,7 @@ from CNNNet_${tc.fullArchitectureName} import Net
class ${tc.fileNameWithoutEnding}:
_model_dir_ = "model/${tc.componentName}/"
_model_prefix_ = "model"
_input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")},)</#list>]
_input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")},)<#if input?has_next>,</#if></#list>]
def __init__(self):
self.weight_initializer = mx.init.Normal()
......@@ -48,7 +48,7 @@ class ${tc.fileNameWithoutEnding}:
self.net = Net(data_mean=data_mean, data_std=data_std)
self.net.collect_params().initialize(self.weight_initializer, ctx=context)
self.net.hybridize()
self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))
self.net(<#list tc.architecture.inputs as input>mx.nd.zeros((1,)+self._input_shapes_[${input?index}], ctx=context)<#if input?has_next>,</#if></#list>)
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
......
......@@ -74,5 +74,5 @@ class Net(gluon.HybridBlock):
with self.name_scope():
${tc.include(tc.architecture.body, "ARCHITECTURE_DEFINITION")}
def hybrid_forward(self, F, x):
def hybrid_forward(self, F, <#list tc.architecture.inputs as input>${input}<#if input?has_next>, </#if></#list>):
${tc.include(tc.architecture.body, "FORWARD_FUNCTION")}
\ No newline at end of file
......@@ -2,11 +2,11 @@
<#if mode == "ARCHITECTURE_DEFINITION">
if not data_mean is None:
assert(not data_std is None)
self.input_normalization = ZScoreNormalization(data_mean=data_mean, data_std=data_std)
self.${element.name}_input_normalization = ZScoreNormalization(data_mean=data_mean, data_std=data_std)
else:
self.input_normalization = NoNormalization()
self.${element.name}_input_normalization = NoNormalization()
</#if>
<#if mode == "FORWARD_FUNCTION">
${element.name} = self.input_normalization(x)
${element.name} = self.${element.name}_input_normalization(${element.name})
</#if>
\ No newline at end of file
import numpy as np
class StrategyBuilder(object):
def __init__(self):
pass
def build_by_params(
self,
method='epsgreedy',
epsilon=0.5,
min_epsilon=0.05,
epsilon_decay_method='no',
epsilon_decay=0.0,
epsilon_decay_start=0,
action_dim=None,
action_low=None,
action_high=None,
mu=0.0,
theta=0.5,
sigma=0.3
):
if epsilon_decay_method == 'linear':
decay = LinearDecay(
eps_decay=epsilon_decay, min_eps=min_epsilon,
decay_start=epsilon_decay_start)
else:
decay = NoDecay()
if method == 'epsgreedy':
assert action_dim is not None
assert len(action_dim) == 1
return EpsilonGreedyStrategy(
eps=epsilon, number_of_actions=action_dim[0],
decay_method=decay)
elif method == 'ornstein_uhlenbeck':
assert action_dim is not None
assert action_low is not None
assert action_high is not None
assert mu is not None
assert theta is not None
assert sigma is not None
return OrnsteinUhlenbeckStrategy(
action_dim, action_low, action_high, epsilon, mu, theta,
sigma, decay)
else:
assert action_dim is not None
assert len(action_dim) == 1
return GreedyStrategy()
class BaseDecay(object):
def __init__(self):
pass
def decay(self, *args):
raise NotImplementedError
def __call__(self, *args):
return self.decay(*args)
class NoDecay(BaseDecay):
def __init__(self):
super(NoDecay, self).__init__()
def decay(self, cur_eps, episode):
return cur_eps
class LinearDecay(BaseDecay):
def __init__(self, eps_decay, min_eps=0, decay_start=0):
super(LinearDecay, self).__init__()
self.eps_decay = eps_decay
self.min_eps = min_eps
self.decay_start = decay_start
def decay(self, cur_eps, episode):
if episode < self.decay_start:
return cur_eps
else:
return max(cur_eps - self.eps_decay, self.min_eps)
class BaseStrategy(object):
def __init__(self, decay_method):
self._decay_method = decay_method
def select_action(self, values, decay_method):
raise NotImplementedError
def decay(self, episode):
self.cur_eps = self._decay_method.decay(self.cur_eps, episode)
def reset(self):
pass
class EpsilonGreedyStrategy(BaseStrategy):
def __init__(self, eps, number_of_actions, decay_method):
super(EpsilonGreedyStrategy, self).__init__(decay_method)
self.eps = eps
self.cur_eps = eps
self.__number_of_actions = number_of_actions
def select_action(self, values):
do_exploration = (np.random.rand() < self.cur_eps)
if do_exploration:
action = np.random.randint(low=0, high=self.__number_of_actions)
else:
action = values.asnumpy().argmax()
return action
class GreedyStrategy(BaseStrategy):
def __init__(self):
super(GreedyStrategy, self).__init__(None)
def select_action(self, values):
return values.asnumpy().argmax()
def decay(self):
pass
class OrnsteinUhlenbeckStrategy(BaseStrategy):
"""
Ornstein-Uhlenbeck process: dxt = theta * (mu - xt) * dt + sigma * dWt
where Wt denotes the Wiener process.
"""
def __init__(
self,