Aufgrund einer Störung des s3 Storage, könnten in nächster Zeit folgende GitLab Funktionen nicht zur Verfügung stehen: LFS, Container Registry, Job Artifacs, Uploads (Wiki, Bilder, Projekt-Exporte). Wir bitten um Verständnis. Es wird mit Hochdruck an der Behebung des Problems gearbeitet. Weitere Informationen zur Störung des Object Storage finden Sie hier: https://maintenance.itc.rwth-aachen.de/ticket/status/messages/59-object-storage-pilot

Aufgrund einer Wartung wird GitLab am 03.08. zwischen 8:00 und 9:00 Uhr kurzzeitig nicht zur Verfügung stehen. / Due to maintenance, GitLab will be temporarily unavailable on 03.08. between 8:00 and 9:00 am.

Commit 053bf612 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'implement-ddpg' into 'master'

Implement ddpg

See merge request !15
parents 736a2256 78559a89
Pipeline #148384 passed with stages
in 4 minutes
......@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId>
<version>0.2.0-SNAPSHOT</version>
<version>0.2.1-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -16,8 +16,8 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.0-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.0-SNAPSHOT</CNNTrain.version>
<CNNArch2MXNet.version>0.2.14-SNAPSHOT</CNNArch2MXNet.version>
<CNNTrain.version>0.3.2-SNAPSHOT</CNNTrain.version>
<CNNArch2MXNet.version>0.2.15-SNAPSHOT</CNNArch2MXNet.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.1</EMADL2PythonWrapper.version>
......
......@@ -20,11 +20,13 @@
*/
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.IOSymbol;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNet;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.Target;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration;
import de.se_rwth.commons.logging.Log;
import java.util.HashMap;
import java.util.Map;
......@@ -34,35 +36,81 @@ public class CNNArch2Gluon extends CNNArch2MxNet {
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
@Override
public Map<String, String> generateStrings(ArchitectureSymbol architecture){
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
Map<String, String> fileContentMap = compileFileContentMap(architecture);
checkValidGeneration(architecture);
return fileContentMap;
}
public Map<String, String> generateStringsAllowMultipleIO(ArchitectureSymbol architecture, Boolean pythonFilesOnly) {
Map<String, String> fileContentMap;
if (pythonFilesOnly) {
fileContentMap = compilePythonFilesOnlyContentMap(architecture);
} else {
fileContentMap = compileFileContentMap(architecture);
}
checkValidOutputTypes(architecture);
return fileContentMap;
}
private void checkValidOutputTypes(ArchitectureSymbol architecture) {
if (((IOSymbol)architecture.getOutputs().get(0)).getDefinition().getType().getWidth() != 1
|| ((IOSymbol)architecture.getOutputs().get(0)).getDefinition().getType().getHeight() != 1) {
Log.error("This cnn architecture has a multi-dimensional output, which is currently not supported by" +
" the code generator.", architecture.getSourcePosition());
}
}
private Map<String, String> compilePythonFiles(CNNArch2GluonTemplateController controller, ArchitectureSymbol architecture) {
Map<String, String> fileContentMap = new HashMap<>();
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration);
Map.Entry<String, String> temp;
temp = archTc.process("CNNPredictor", Target.CPP);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("CNNNet", Target.PYTHON);
temp = controller.process("CNNNet", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
if (architecture.getDataPath() != null) {
temp = archTc.process("CNNDataLoader", Target.PYTHON);
temp = controller.process("CNNDataLoader", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
}
temp = archTc.process("CNNCreator", Target.PYTHON);
temp = controller.process("CNNCreator", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
return fileContentMap;
}
private Map<String, String> compileCppFiles(CNNArch2GluonTemplateController controller) {
Map<String, String> fileContentMap = new HashMap<>();
Map.Entry<String, String> temp;
temp = controller.process("CNNPredictor", Target.CPP);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("execute", Target.CPP);
temp = controller.process("execute", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
temp = archTc.process("CNNBufferFile", Target.CPP);
temp = controller.process("CNNBufferFile", Target.CPP);
fileContentMap.put("CNNBufferFile.h", temp.getValue());
checkValidGeneration(architecture);
return fileContentMap;
}
private Map<String, String> compileFileContentMap(ArchitectureSymbol architecture) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
Map<String, String> fileContentMap = new HashMap<>();
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration);
fileContentMap.putAll(compilePythonFiles(archTc, architecture));
fileContentMap.putAll(compileCppFiles(archTc));
return fileContentMap;
}
}
private Map<String, String> compilePythonFilesOnlyContentMap(ArchitectureSymbol architecture) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration);
return compilePythonFiles(archTc, architecture);
}
}
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import com.google.common.collect.Maps;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerationPair;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerator;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
......@@ -10,7 +12,9 @@ import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm;
import de.monticore.lang.monticar.cnntrain._symboltable.RewardFunctionSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapperStandaloneApi;
......@@ -25,6 +29,7 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.stream.Collectors;
public class CNNTrain2Gluon extends CNNTrain2MxNet {
private static final String REINFORCEMENT_LEARNING_FRAMEWORK_MODULE = "reinforcement_learning";
......@@ -62,6 +67,15 @@ public class CNNTrain2Gluon extends CNNTrain2MxNet {
public void generate(Path modelsDirPath, String rootModelName) {
ConfigurationSymbol configuration = this.getConfigurationSymbol(modelsDirPath, rootModelName);
if (configuration.getLearningMethod().equals(LearningMethod.REINFORCEMENT)) {
throw new IllegalStateException("Cannot call generate of reinforcement configuration without specifying " +
"the trained architecture");
}
generateFilesFromConfigurationSymbol(configuration);
}
private void generateFilesFromConfigurationSymbol(ConfigurationSymbol configuration) {
Map<String, String> fileContents = this.generateStrings(configuration);
GeneratorCPP genCPP = new GeneratorCPP();
genCPP.setGenerationTargetPath(this.getGenerationTargetPath());
......@@ -78,6 +92,13 @@ public class CNNTrain2Gluon extends CNNTrain2MxNet {
}
}
public void generate(Path modelsDirPath, String rootModelName, TrainedArchitecture trainedArchitecture) {
ConfigurationSymbol configurationSymbol = this.getConfigurationSymbol(modelsDirPath, rootModelName);
configurationSymbol.setTrainedArchitecture(trainedArchitecture);
this.setRootProjectModelsDir(modelsDirPath.toString());
generateFilesFromConfigurationSymbol(configurationSymbol);
}
@Override
public Map<String, String> generateStrings(ConfigurationSymbol configuration) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
......@@ -98,8 +119,30 @@ public class CNNTrain2Gluon extends CNNTrain2MxNet {
fileContentMap.put("supervised_trainer.py", cnnSupervisedTrainerContent);
} else if (configData.isReinforcementLearning()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
if (rlAlgorithm.equals(RLAlgorithm.DDPG)) {
CriticNetworkGenerator criticNetworkGenerator = new CriticNetworkGenerator();
criticNetworkGenerator.setGenerationTargetPath(
Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString());
if (getRootProjectModelsDir().isPresent()) {
criticNetworkGenerator.setRootModelsDir(getRootProjectModelsDir().get());
} else {
Log.error("No root model dir set");
}
CriticNetworkGenerationPair criticNetworkResult
= criticNetworkGenerator.generateCriticNetworkContent(templateConfiguration, configuration);
fileContentMap.putAll(criticNetworkResult.getFileContent().entrySet().stream().collect(Collectors.toMap(
k -> REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
);
ftlContext.put("criticInstanceName", criticNetworkResult.getCriticNetworkName());
}
ftlContext.put("trainerName", trainerName);
Map<String, String> rlFrameworkContentMap = constructReinforcementLearningFramework(templateConfiguration, ftlContext);
Map<String, String> rlFrameworkContentMap = constructReinforcementLearningFramework(templateConfiguration, ftlContext, rlAlgorithm);
fileContentMap.putAll(rlFrameworkContentMap);
final String reinforcementTrainerContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/Trainer.ftl");
......@@ -161,23 +204,34 @@ public class CNNTrain2Gluon extends CNNTrain2MxNet {
}
private Map<String, String> constructReinforcementLearningFramework(
final TemplateConfiguration templateConfiguration, final Map<String, Object> ftlContext) {
final TemplateConfiguration templateConfiguration,
final Map<String, Object> ftlContext,
RLAlgorithm rlAlgorithm) {
Map<String, String> fileContentMap = Maps.newHashMap();
ftlContext.put("rlFrameworkModule", REINFORCEMENT_LEARNING_FRAMEWORK_MODULE);
final String reinforcementAgentContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/agent/Agent.ftl");
final String loggerContent = templateConfiguration.processTemplate(ftlContext,
"reinforcement/util/Logger.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/cnnarch_logger.py", loggerContent);
final String reinforcementAgentContent = templateConfiguration.processTemplate(ftlContext,
"reinforcement/agent/Agent.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/agent.py", reinforcementAgentContent);
final String reinforcementPolicyContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/agent/ActionPolicy.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/action_policy.py", reinforcementPolicyContent);
final String reinforcementStrategyContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/agent/Strategy.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/strategy.py", reinforcementStrategyContent);
final String replayMemoryContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/agent/ReplayMemory.ftl");
final String replayMemoryContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/agent/ReplayMemory.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/replay_memory.py", replayMemoryContent);
final String environmentContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/environment/Environment.ftl");
final String environmentContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/environment/Environment.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/environment.py", environmentContent);
final String utilContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/util/Util.ftl");
final String utilContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/util/Util.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/util.py", utilContent);
final String initContent = "";
......
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.ConfigurationData;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.*;
/**
*
......@@ -25,10 +25,19 @@ public class ReinforcementConfigurationData extends ConfigurationData {
private static final String AST_ENTRY_AGENT_NAME = "agent_name";
private static final String AST_ENTRY_USE_DOUBLE_DQN = "use_double_dqn";
private static final String AST_ENTRY_LOSS = "loss";
private static final String AST_ENTRY_RL_ALGORITHM = "rl_algorithm";
private static final String AST_ENTRY_REPLAY_MEMORY = "replay_memory";
private static final String AST_ENTRY_ACTION_SELECTION = "action_selection";
private static final String AST_ENTRY_STRATEGY = "strategy";
private static final String AST_ENTRY_ENVIRONMENT = "environment";
private static final String AST_ENTRY_START_TRAINING_AT = "start_training_at";
private static final String AST_SOFT_TARGET_UPDATE_RATE = "soft_target_update_rate";
private static final String AST_EVALUATION_SAMPLES = "evaluation_samples";
private static final String ENVIRONMENT_PARAM_REWARD_TOPIC = "reward_topic";
private static final String ENVIRONMENT_ROS = "ros_interface";
private static final String ENVIRONMENT_GYM = "gym";
private static final String STRATEGY_ORNSTEIN_UHLENBECK = "ornstein_uhlenbeck";
public ReinforcementConfigurationData(ConfigurationSymbol configuration, String instanceName) {
super(configuration, instanceName);
......@@ -97,6 +106,67 @@ public class ReinforcementConfigurationData extends ConfigurationData {
? null : (Boolean)retrieveConfigurationEntryValueByKey(AST_ENTRY_USE_DOUBLE_DQN);
}
public Double getSoftTargetUpdateRate() {
return !configurationContainsKey(AST_SOFT_TARGET_UPDATE_RATE)
? null : (Double)retrieveConfigurationEntryValueByKey(AST_SOFT_TARGET_UPDATE_RATE);
}
public Integer getStartTrainingAt() {
return !configurationContainsKey(AST_ENTRY_START_TRAINING_AT)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_START_TRAINING_AT);
}
public Integer getEvaluationSamples() {
return !configurationContainsKey(AST_EVALUATION_SAMPLES)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_EVALUATION_SAMPLES);
}
public RLAlgorithm getRlAlgorithm() {
if (!isReinforcementLearning()) {
return null;
}
return !configurationContainsKey(AST_ENTRY_RL_ALGORITHM)
? RLAlgorithm.DQN : (RLAlgorithm)retrieveConfigurationEntryValueByKey(AST_ENTRY_RL_ALGORITHM);
}
public String getInputNameOfTrainedArchitecture() {
if (!this.getConfiguration().getTrainedArchitecture().isPresent()) {
throw new IllegalStateException("No trained architecture set");
}
TrainedArchitecture trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
// We allow only one input, the first one is the only input
return trainedArchitecture.getInputs().get(0);
}
public String getOutputNameOfTrainedArchitecture() {
if (!this.getConfiguration().getTrainedArchitecture().isPresent()) {
throw new IllegalStateException("No trained architecture set");
}
TrainedArchitecture trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
// We allow only one output, the first one is the only output
return trainedArchitecture.getOutputs().get(0);
}
public List<Integer> getStateDim() {
if (!this.getConfiguration().getTrainedArchitecture().isPresent()) {
return null;
}
final String inputName = getInputNameOfTrainedArchitecture();
TrainedArchitecture trainedArchitecture = this.getConfiguration().getTrainedArchitecture().get();
return trainedArchitecture.getDimensions().get(inputName);
}
public List<Integer> getActionDim() {
if (!this.getConfiguration().getTrainedArchitecture().isPresent()) {
return null;
}
final String outputName = getOutputNameOfTrainedArchitecture();
TrainedArchitecture trainedArchitecture = this.getConfiguration().getTrainedArchitecture().get();
return trainedArchitecture.getDimensions().get(outputName);
}
public String getLoss() {
return !configurationContainsKey(AST_ENTRY_LOSS)
? null : retrieveConfigurationEntryValueByKey(AST_ENTRY_LOSS).toString();
......@@ -106,8 +176,36 @@ public class ReinforcementConfigurationData extends ConfigurationData {
return getMultiParamEntry(AST_ENTRY_REPLAY_MEMORY, "method");
}
public Map<String, Object> getActionSelection() {
return getMultiParamEntry(AST_ENTRY_ACTION_SELECTION, "method");
public Map<String, Object> getStrategy() {
assert isReinforcementLearning(): "Strategy parameter only for reinforcement learning but called in a " +
" non reinforcement learning context";
Map<String, Object> strategyParams = getMultiParamEntry(AST_ENTRY_STRATEGY, "method");
if (strategyParams.get("method").equals(STRATEGY_ORNSTEIN_UHLENBECK)) {
assert getConfiguration().getTrainedArchitecture().isPresent(): "Architecture not present," +
" but reinforcement training";
TrainedArchitecture trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
final String actionPortName = getOutputNameOfTrainedArchitecture();
Range actionRange = trainedArchitecture.getRanges().get(actionPortName);
if (actionRange.isLowerLimitInfinity() && actionRange.isUpperLimitInfinity()) {
strategyParams.put("action_low", null);
strategyParams.put("action_high", null);
} else if(!actionRange.isLowerLimitInfinity() && actionRange.isUpperLimitInfinity()) {
assert actionRange.getLowerLimit().isPresent();
strategyParams.put("action_low", actionRange.getLowerLimit().get());
strategyParams.put("action_high", null);
} else if (actionRange.isLowerLimitInfinity() && !actionRange.isUpperLimitInfinity()) {
assert actionRange.getUpperLimit().isPresent();
strategyParams.put("action_low", null);
strategyParams.put("action_high", actionRange.getUpperLimit().get());
} else {
assert actionRange.getLowerLimit().isPresent();
assert actionRange.getUpperLimit().isPresent();
strategyParams.put("action_low", actionRange.getLowerLimit().get());
strategyParams.put("action_high", actionRange.getUpperLimit().get());
}
}
return strategyParams;
}
public Map<String, Object> getEnvironment() {
......@@ -136,6 +234,16 @@ public class ReinforcementConfigurationData extends ConfigurationData {
.getRewardFunctionParameter().orElse(null));
}
public boolean isDiscreteRlAlgorithm() {
assert isReinforcementLearning();
return getRlAlgorithm().equals(RLAlgorithm.DQN);
}
public boolean isContinuousRlAlgorithm() {
assert isReinforcementLearning();
return getRlAlgorithm().equals(RLAlgorithm.DDPG);
}
public Map<String, Object> getRewardFunctionStateParameter() {
if (!getRlRewardFunctionParameter().isPresent()
|| !getRlRewardFunctionParameter().get().getInputStateParameterName().isPresent()) {
......@@ -159,6 +267,50 @@ public class ReinforcementConfigurationData extends ConfigurationData {
return getRlRewardFunctionParameter().get().getOutputParameterName().orElse(null);
}
public String getCriticOptimizerName() {
if (!getConfiguration().getCriticOptimizer().isPresent()) {
return null;
}
return getConfiguration().getCriticOptimizer().get().getName();
}
public Map<String, String> getCriticOptimizerParams() {
// get classes for single enum values
assert getConfiguration().getCriticOptimizer().isPresent():
"Critic optimizer params called although, not present";
List<Class> lrPolicyClasses = new ArrayList<>();
for (LRPolicy enum_value: LRPolicy.values()) {
lrPolicyClasses.add(enum_value.getClass());
}
Map<String, String> mapToStrings = new HashMap<>();
Map<String, OptimizerParamSymbol> optimizerParams =
getConfiguration().getCriticOptimizer().get().getOptimizerParamMap();
for (Map.Entry<String, OptimizerParamSymbol> entry : optimizerParams.entrySet()) {
String paramName = entry.getKey();
String valueAsString = entry.getValue().toString();
Class realClass = entry.getValue().getValue().getValue().getClass();
if (realClass == Boolean.class) {
valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False";
} else if (lrPolicyClasses.contains(realClass)) {
valueAsString = "'" + valueAsString + "'";
}
mapToStrings.put(paramName, valueAsString);
}
return mapToStrings;
}
public boolean hasRosRewardTopic() {
Map<String, Object> environmentParameters = getMultiParamEntry(AST_ENTRY_ENVIRONMENT, "environment");
if (environmentParameters == null
|| !environmentParameters.containsKey("environment")) {
return false;
}
return environmentParameters.containsKey(ENVIRONMENT_PARAM_REWARD_TOPIC);
}
private Map<String, Object> getMultiParamEntry(final String key, final String valueName) {
if (!configurationContainsKey(key)) {
return null;
......
package de.monticore.lang.monticar.cnnarch.gluongenerator.annotations;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.IOSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.lang.monticar.ranges._ast.ASTRange;
import de.monticore.symboltable.CommonSymbol;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static com.google.common.base.Preconditions.checkNotNull;
public class ArchitectureAdapter implements TrainedArchitecture {
private ArchitectureSymbol architectureSymbol;
public ArchitectureAdapter(final ArchitectureSymbol architectureSymbol) {
checkNotNull(architectureSymbol);
this.architectureSymbol = architectureSymbol;
}
@Override
public List<String> getInputs() {
return getIOInputSymbols().stream()
.map(CommonSymbol::getName)
.collect(Collectors.toList());
}
@Override
public List<String> getOutputs() {
return getIOOutputSymbols().stream()
.map(CommonSymbol::getName)
.collect(Collectors.toList());
}
@Override
public Map<String, List<Integer>> getDimensions() {
return getAllIOSymbols().stream().collect(Collectors.toMap(CommonSymbol::getName,
s-> s.getDefinition().getType().getDimensions()));
}
@Override
public Map<String, Range> getRanges() {
return getAllIOSymbols().stream().collect(Collectors.toMap(CommonSymbol::getName,
s -> astRangeToTrainRange(s.getDefinition().getType().getDomain().getRangeOpt().orElse(null))));
}
@Override
public Map<String, String> getTypes() {
return getAllIOSymbols().stream().collect(Collectors.toMap(CommonSymbol::getName,
s -> s.getDefinition().getType().getDomain().getName()));
}
private Range astRangeToTrainRange(final ASTRange range) {
if (range == null || (range.hasNoLowerLimit() && range.hasNoUpperLimit())) {
return Range.withInfinityLimits();
} else if (range.hasNoUpperLimit() && !range.hasNoLowerLimit()) {
double lowerLimit = range.getStartValue().doubleValue();
return Range.withUpperInfinityLimit(lowerLimit);
} else if (!range.hasNoUpperLimit() && range.hasNoLowerLimit()) {
double upperLimit = range.getEndValue().doubleValue();
return Range.withLowerInfinityLimit(upperLimit);