Commit d83c8f53 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'added-trainer' into 'master'

Added trainer

See merge request !6
parents 29d9750a 23eb5106
Pipeline #69519 passed with stages
in 1 minute and 56 seconds
...@@ -8,14 +8,15 @@ ...@@ -8,14 +8,15 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId> <artifactId>cnnarch-mxnet-generator</artifactId>
<version>0.2.1-SNAPSHOT</version> <version>0.2.2-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
<properties> <properties>
<!-- .. SE-Libraries .................................................. --> <!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.2.1-SNAPSHOT</CNNArch.version> <CNNArch.version>0.2.3-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.2.4-SNAPSHOT</CNNTrain.version>
<!-- .. Libraries .................................................. --> <!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version> <guava.version>18.0</guava.version>
...@@ -70,6 +71,20 @@ ...@@ -70,6 +71,20 @@
<scope>provided</scope> <scope>provided</scope>
</dependency> </dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>${CNNTrain.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>${CNNTrain.version}</version>
<classifier>${grammars.classifier}</classifier>
<scope>provided</scope>
</dependency>
<!-- .. Test Libraries ............................................... --> <!-- .. Test Libraries ............................................... -->
<dependency> <dependency>
...@@ -127,7 +142,7 @@ ...@@ -127,7 +142,7 @@
<configuration> <configuration>
<archive> <archive>
<manifest> <manifest>
<mainClass>de.monticore.lang.monticar.cnnarch.generator.CNNArchGeneratorCli</mainClass> <mainClass>de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNetCli</mainClass>
</manifest> </manifest>
</archive> </archive>
<descriptorRefs> <descriptorRefs>
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>. * License along with this project. If not, see <http://www.gnu.org/licenses/>.
* ******************************************************************************* * *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnnarch.generator; package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
......
...@@ -18,13 +18,15 @@ ...@@ -18,13 +18,15 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>. * License along with this project. If not, see <http://www.gnu.org/licenses/>.
* ******************************************************************************* * *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnnarch.generator; package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.io.paths.ModelPath; import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos; import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage; import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.symboltable.GlobalScope; import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope; import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log; import de.se_rwth.commons.logging.Log;
...@@ -33,15 +35,13 @@ import java.io.File; ...@@ -33,15 +35,13 @@ import java.io.File;
import java.io.FileWriter; import java.io.FileWriter;
import java.io.IOException; import java.io.IOException;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.HashMap; import java.util.*;
import java.util.Map;
import java.util.Optional;
public class CNNArchGenerator { public class CNNArch2MxNet implements CNNArchGenerator {
private String generationTargetPath; private String generationTargetPath;
public CNNArchGenerator() { public CNNArch2MxNet() {
setGenerationTargetPath("./target/generated-sources-cnnarch/"); setGenerationTargetPath("./target/generated-sources-cnnarch/");
} }
...@@ -79,6 +79,24 @@ public class CNNArchGenerator { ...@@ -79,6 +79,24 @@ public class CNNArchGenerator {
} }
} }
@Override
public Map<String, String> generateTrainer(List<ConfigurationSymbol> configurations, List<String> instanceNames, String mainComponentName) {
int numberOfNetworks = configurations.size();
if (configurations.size() != instanceNames.size()){
throw new IllegalStateException(
"The number of configurations and the number of instances for generation of the CNNTrainer is not equal. " +
"This should have been checked previously.");
}
List<ConfigurationData> configDataList = new ArrayList<>();
for(int i = 0; i < numberOfNetworks; i++){
configDataList.add(new ConfigurationData(configurations.get(i), instanceNames.get(i)));
}
Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
return Collections.singletonMap(
"CNNTrainer_" + mainComponentName + ".py",
TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl"));
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method. //check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public Map<String, String> generateStrings(ArchitectureSymbol architecture){ public Map<String, String> generateStrings(ArchitectureSymbol architecture){
Map<String, String> fileContentMap = new HashMap<>(); Map<String, String> fileContentMap = new HashMap<>();
...@@ -104,21 +122,19 @@ public class CNNArchGenerator { ...@@ -104,21 +122,19 @@ public class CNNArchGenerator {
private void checkValidGeneration(ArchitectureSymbol architecture){ private void checkValidGeneration(ArchitectureSymbol architecture){
if (architecture.getInputs().size() > 1){ if (architecture.getInputs().size() > 1){
Log.warn("This cnn architecture has multiple inputs, " + Log.error("This cnn architecture has multiple inputs, " +
"which is currently not supported by the generator. " + "which is currently not supported by the mxnetgenerator. "
"The generated code will not work correctly."
, architecture.getSourcePosition()); , architecture.getSourcePosition());
} }
if (architecture.getOutputs().size() > 1){ if (architecture.getOutputs().size() > 1){
Log.warn("This cnn architecture has multiple outputs, " + Log.error("This cnn architecture has multiple outputs, " +
"which is currently not supported by the generator. " + "which is currently not supported by the mxnetgenerator. "
"The generated code will not work correctly."
, architecture.getSourcePosition()); , architecture.getSourcePosition());
} }
if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 || if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 ||
architecture.getOutputs().get(0).getDefinition().getType().getHeight() != 1){ architecture.getOutputs().get(0).getDefinition().getType().getHeight() != 1){
Log.error("This cnn architecture has a multi-dimensional output, " + Log.error("This cnn architecture has a multi-dimensional output, " +
"which is currently not supported by the generator." "which is currently not supported by the mxnetgenerator."
, architecture.getSourcePosition()); , architecture.getSourcePosition());
} }
} }
...@@ -143,5 +159,4 @@ public class CNNArchGenerator { ...@@ -143,5 +159,4 @@ public class CNNArchGenerator {
writer.close(); writer.close();
} }
} }
} }
...@@ -18,14 +18,14 @@ ...@@ -18,14 +18,14 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>. * License along with this project. If not, see <http://www.gnu.org/licenses/>.
* ******************************************************************************* * *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnnarch.generator; package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import org.apache.commons.cli.*; import org.apache.commons.cli.*;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
public class CNNArchGeneratorCli { public class CNNArch2MxNetCli {
public static final Option OPTION_MODELS_PATH = Option.builder("m") public static final Option OPTION_MODELS_PATH = Option.builder("m")
.longOpt("models-dir") .longOpt("models-dir")
...@@ -48,7 +48,7 @@ public class CNNArchGeneratorCli { ...@@ -48,7 +48,7 @@ public class CNNArchGeneratorCli {
.required(false) .required(false)
.build(); .build();
private CNNArchGeneratorCli() { private CNNArch2MxNetCli() {
} }
public static void main(String[] args) { public static void main(String[] args) {
...@@ -84,7 +84,7 @@ public class CNNArchGeneratorCli { ...@@ -84,7 +84,7 @@ public class CNNArchGeneratorCli {
Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt())); Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()));
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt()); String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt()); String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
CNNArchGenerator generator = new CNNArchGenerator(); CNNArch2MxNet generator = new CNNArch2MxNet();
if (outputPath != null){ if (outputPath != null){
generator.setGenerationTargetPath(outputPath); generator.setGenerationTargetPath(outputPath);
} }
......
...@@ -18,17 +18,12 @@ ...@@ -18,17 +18,12 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>. * License along with this project. If not, see <http://www.gnu.org/licenses/>.
* ******************************************************************************* * *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnnarch.generator; package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.*; import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid; import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid;
import de.monticore.lang.monticar.cnnarch.predefined.Softmax; import de.monticore.lang.monticar.cnnarch.predefined.Softmax;
import de.se_rwth.commons.logging.Log;
import freemarker.template.Configuration;
import freemarker.template.Template;
import freemarker.template.TemplateException;
import java.io.IOException;
import java.io.StringWriter; import java.io.StringWriter;
import java.io.Writer; import java.io.Writer;
import java.util.*; import java.util.*;
...@@ -41,14 +36,15 @@ public class CNNArchTemplateController { ...@@ -41,14 +36,15 @@ public class CNNArchTemplateController {
public static final String ELEMENT_DATA_KEY = "element"; public static final String ELEMENT_DATA_KEY = "element";
private LayerNameCreator nameManager; private LayerNameCreator nameManager;
private Configuration freemarkerConfig = TemplateConfiguration.get();
private ArchitectureSymbol architecture; private ArchitectureSymbol architecture;
//temporary attributes. They are set after calling process()
private Writer writer; private Writer writer;
private String mainTemplateNameWithoutEnding; private String mainTemplateNameWithoutEnding;
private Target targetLanguage; private Target targetLanguage;
private ArchitectureElementData dataElement; private ArchitectureElementData dataElement;
public CNNArchTemplateController(ArchitectureSymbol architecture) { public CNNArchTemplateController(ArchitectureSymbol architecture) {
setArchitecture(architecture); setArchitecture(architecture);
} }
...@@ -57,14 +53,6 @@ public class CNNArchTemplateController { ...@@ -57,14 +53,6 @@ public class CNNArchTemplateController {
return mainTemplateNameWithoutEnding + "_" + getFullArchitectureName(); return mainTemplateNameWithoutEnding + "_" + getFullArchitectureName();
} }
public Target getTargetLanguage(){
return targetLanguage;
}
public void setTargetLanguage(Target targetLanguage) {
this.targetLanguage = targetLanguage;
}
public ArchitectureElementData getCurrentElement() { public ArchitectureElementData getCurrentElement() {
return dataElement; return dataElement;
} }
...@@ -137,25 +125,10 @@ public class CNNArchTemplateController { ...@@ -137,25 +125,10 @@ public class CNNArchTemplateController {
public void include(String relativePath, String templateWithoutFileEnding, Writer writer){ public void include(String relativePath, String templateWithoutFileEnding, Writer writer){
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING; String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
Map<String, Object> ftlContext = new HashMap<>();
try { ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
Template template = freemarkerConfig.getTemplate(templatePath); ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
Map<String, Object> ftlContext = new HashMap<>(); TemplateConfiguration.processTemplate(ftlContext, templatePath, writer);
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
this.writer = writer;
template.process(ftlContext, writer);
this.writer = null;
}
catch (IOException e) {
Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
catch (TemplateException e){
Log.error("An exception occured in template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
} }
public void include(IOSymbol ioElement, Writer writer){ public void include(IOSymbol ioElement, Writer writer){
...@@ -229,18 +202,16 @@ public class CNNArchTemplateController { ...@@ -229,18 +202,16 @@ public class CNNArchTemplateController {
StringWriter writer = new StringWriter(); StringWriter writer = new StringWriter();
this.mainTemplateNameWithoutEnding = templateNameWithoutEnding; this.mainTemplateNameWithoutEnding = templateNameWithoutEnding;
this.targetLanguage = targetLanguage; this.targetLanguage = targetLanguage;
include("", templateNameWithoutEnding, writer); this.writer = writer;
include("", templateNameWithoutEnding, writer);
String fileEnding = targetLanguage.toString(); String fileEnding = targetLanguage.toString();
if (targetLanguage == Target.CPP){
fileEnding = ".h";
}
String fileName = getFileNameWithoutEnding() + fileEnding; String fileName = getFileNameWithoutEnding() + fileEnding;
Map.Entry<String,String> fileContent = new AbstractMap.SimpleEntry<>(fileName, writer.toString()); Map.Entry<String,String> fileContent = new AbstractMap.SimpleEntry<>(fileName, writer.toString());
this.mainTemplateNameWithoutEnding = null; this.mainTemplateNameWithoutEnding = null;
this.targetLanguage = null; this.targetLanguage = null;
this.writer = null;
return fileContent; return fileContent;
} }
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class ConfigurationData {
ConfigurationSymbol configuration;
String instanceName;
public ConfigurationData(ConfigurationSymbol configuration, String instanceName) {
this.configuration = configuration;
this.instanceName = instanceName;
}
public ConfigurationSymbol getConfiguration() {
return configuration;
}
public String getInstanceName() {
return instanceName;
}
public String getNumEpoch() {
if (!getConfiguration().getEntryMap().containsKey("num_epoch")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("num_epoch").getValue());
}
public String getBatchSize() {
if (!getConfiguration().getEntryMap().containsKey("batch_size")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("batch_size") .getValue());
}
public Boolean getLoadCheckpoint() {
if (!getConfiguration().getEntryMap().containsKey("load_checkpoint")) {
return null;
}
return (Boolean) getConfiguration().getEntry("load_checkpoint").getValue().getValue();
}
public Boolean getNormalize() {
if (!getConfiguration().getEntryMap().containsKey("normalize")) {
return null;
}
return (Boolean) getConfiguration().getEntry("normalize").getValue().getValue();
}
public String getContext() {
if (!getConfiguration().getEntryMap().containsKey("context")) {
return null;
}
return getConfiguration().getEntry("context").getValue().toString();
}
public String getEvalMetric() {
if (!getConfiguration().getEntryMap().containsKey("eval_metric")) {
return null;
}
return getConfiguration().getEntry("eval_metric").getValue().toString();
}
public String getOptimizerName() {
if (getConfiguration().getOptimizer() == null) {
return null;
}
return getConfiguration().getOptimizer().getName();
}
public Map<String, String> getOptimizerParams() {
// get classes for single enum values
List<Class> lrPolicyClasses = new ArrayList<>();
for (LRPolicy enum_value: LRPolicy.values()) {
lrPolicyClasses.add(enum_value.getClass());
}
Map<String, String> mapToStrings = new HashMap<>();
Map<String, OptimizerParamSymbol> optimizerParams = getConfiguration().getOptimizer().getOptimizerParamMap();
for (Map.Entry<String, OptimizerParamSymbol> entry : optimizerParams.entrySet()) {
String paramName = entry.getKey();
String valueAsString = entry.getValue().toString();
Class realClass = entry.getValue().getValue().getValue().getClass();
if (realClass == Boolean.class) {
valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False";
}
else if (lrPolicyClasses.contains(realClass)) {
valueAsString = "'" + valueAsString + "'";
}
mapToStrings.put(paramName, valueAsString);
}
return mapToStrings;
}
}
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>. * License along with this project. If not, see <http://www.gnu.org/licenses/>.
* ******************************************************************************* * *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnnarch.generator; package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.*; import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Convolution; import de.monticore.lang.monticar.cnnarch.predefined.Convolution;
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>. * License along with this project. If not, see <http://www.gnu.org/licenses/>.
* ******************************************************************************* * *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnnarch.generator; package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
//can be removed //can be removed
public enum Target { public enum Target {
...@@ -31,26 +31,7 @@ public enum Target { ...@@ -31,26 +31,7 @@ public enum Target {
CPP{ CPP{
@Override @Override
public String toString() { public String toString() {
return ".cpp"; return ".h";
} }
}; };
public static Target fromString(String target){
switch (target.toLowerCase()){
case "python":
return PYTHON;
case "py":
return PYTHON;
case "cpp":
return CPP;