Commit d83c8f53 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'added-trainer' into 'master'

Added trainer

See merge request !6
parents 29d9750a 23eb5106
Pipeline #69519 passed with stages
in 1 minute and 56 seconds
......@@ -8,14 +8,15 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId>
<version>0.2.1-SNAPSHOT</version>
<version>0.2.2-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.2.1-SNAPSHOT</CNNArch.version>
<CNNArch.version>0.2.3-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.2.4-SNAPSHOT</CNNTrain.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
......@@ -70,6 +71,20 @@
<scope>provided</scope>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>${CNNTrain.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>${CNNTrain.version}</version>
<classifier>${grammars.classifier}</classifier>
<scope>provided</scope>
</dependency>
<!-- .. Test Libraries ............................................... -->
<dependency>
......@@ -127,7 +142,7 @@
<configuration>
<archive>
<manifest>
<mainClass>de.monticore.lang.monticar.cnnarch.generator.CNNArchGeneratorCli</mainClass>
<mainClass>de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNetCli</mainClass>
</manifest>
</archive>
<descriptorRefs>
......
......@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
......
......@@ -18,13 +18,15 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
......@@ -33,15 +35,13 @@ import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.*;
public class CNNArchGenerator {
public class CNNArch2MxNet implements CNNArchGenerator {
private String generationTargetPath;
public CNNArchGenerator() {
public CNNArch2MxNet() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
......@@ -79,6 +79,24 @@ public class CNNArchGenerator {
}
}
@Override
public Map<String, String> generateTrainer(List<ConfigurationSymbol> configurations, List<String> instanceNames, String mainComponentName) {
int numberOfNetworks = configurations.size();
if (configurations.size() != instanceNames.size()){
throw new IllegalStateException(
"The number of configurations and the number of instances for generation of the CNNTrainer is not equal. " +
"This should have been checked previously.");
}
List<ConfigurationData> configDataList = new ArrayList<>();
for(int i = 0; i < numberOfNetworks; i++){
configDataList.add(new ConfigurationData(configurations.get(i), instanceNames.get(i)));
}
Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
return Collections.singletonMap(
"CNNTrainer_" + mainComponentName + ".py",
TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl"));
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public Map<String, String> generateStrings(ArchitectureSymbol architecture){
Map<String, String> fileContentMap = new HashMap<>();
......@@ -104,21 +122,19 @@ public class CNNArchGenerator {
private void checkValidGeneration(ArchitectureSymbol architecture){
if (architecture.getInputs().size() > 1){
Log.warn("This cnn architecture has multiple inputs, " +
"which is currently not supported by the generator. " +
"The generated code will not work correctly."
Log.error("This cnn architecture has multiple inputs, " +
"which is currently not supported by the mxnetgenerator. "
, architecture.getSourcePosition());
}
if (architecture.getOutputs().size() > 1){
Log.warn("This cnn architecture has multiple outputs, " +
"which is currently not supported by the generator. " +
"The generated code will not work correctly."
Log.error("This cnn architecture has multiple outputs, " +
"which is currently not supported by the mxnetgenerator. "
, architecture.getSourcePosition());
}
if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 ||
architecture.getOutputs().get(0).getDefinition().getType().getHeight() != 1){
Log.error("This cnn architecture has a multi-dimensional output, " +
"which is currently not supported by the generator."
"which is currently not supported by the mxnetgenerator."
, architecture.getSourcePosition());
}
}
......@@ -143,5 +159,4 @@ public class CNNArchGenerator {
writer.close();
}
}
}
......@@ -18,14 +18,14 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import org.apache.commons.cli.*;
import java.nio.file.Path;
import java.nio.file.Paths;
public class CNNArchGeneratorCli {
public class CNNArch2MxNetCli {
public static final Option OPTION_MODELS_PATH = Option.builder("m")
.longOpt("models-dir")
......@@ -48,7 +48,7 @@ public class CNNArchGeneratorCli {
.required(false)
.build();
private CNNArchGeneratorCli() {
private CNNArch2MxNetCli() {
}
public static void main(String[] args) {
......@@ -84,7 +84,7 @@ public class CNNArchGeneratorCli {
Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()));
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
CNNArchGenerator generator = new CNNArchGenerator();
CNNArch2MxNet generator = new CNNArch2MxNet();
if (outputPath != null){
generator.setGenerationTargetPath(outputPath);
}
......
......@@ -18,17 +18,12 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid;
import de.monticore.lang.monticar.cnnarch.predefined.Softmax;
import de.se_rwth.commons.logging.Log;
import freemarker.template.Configuration;
import freemarker.template.Template;
import freemarker.template.TemplateException;
import java.io.IOException;
import java.io.StringWriter;
import java.io.Writer;
import java.util.*;
......@@ -41,14 +36,15 @@ public class CNNArchTemplateController {
public static final String ELEMENT_DATA_KEY = "element";
private LayerNameCreator nameManager;
private Configuration freemarkerConfig = TemplateConfiguration.get();
private ArchitectureSymbol architecture;
//temporary attributes. They are set after calling process()
private Writer writer;
private String mainTemplateNameWithoutEnding;
private Target targetLanguage;
private ArchitectureElementData dataElement;
public CNNArchTemplateController(ArchitectureSymbol architecture) {
setArchitecture(architecture);
}
......@@ -57,14 +53,6 @@ public class CNNArchTemplateController {
return mainTemplateNameWithoutEnding + "_" + getFullArchitectureName();
}
public Target getTargetLanguage(){
return targetLanguage;
}
public void setTargetLanguage(Target targetLanguage) {
this.targetLanguage = targetLanguage;
}
public ArchitectureElementData getCurrentElement() {
return dataElement;
}
......@@ -137,25 +125,10 @@ public class CNNArchTemplateController {
public void include(String relativePath, String templateWithoutFileEnding, Writer writer){
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
try {
Template template = freemarkerConfig.getTemplate(templatePath);
Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
this.writer = writer;
template.process(ftlContext, writer);
this.writer = null;
}
catch (IOException e) {
Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
catch (TemplateException e){
Log.error("An exception occured in template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
TemplateConfiguration.processTemplate(ftlContext, templatePath, writer);
}
public void include(IOSymbol ioElement, Writer writer){
......@@ -229,18 +202,16 @@ public class CNNArchTemplateController {
StringWriter writer = new StringWriter();
this.mainTemplateNameWithoutEnding = templateNameWithoutEnding;
this.targetLanguage = targetLanguage;
include("", templateNameWithoutEnding, writer);
this.writer = writer;
include("", templateNameWithoutEnding, writer);
String fileEnding = targetLanguage.toString();
if (targetLanguage == Target.CPP){
fileEnding = ".h";
}
String fileName = getFileNameWithoutEnding() + fileEnding;
Map.Entry<String,String> fileContent = new AbstractMap.SimpleEntry<>(fileName, writer.toString());
this.mainTemplateNameWithoutEnding = null;
this.targetLanguage = null;
this.writer = null;
return fileContent;
}
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class ConfigurationData {
ConfigurationSymbol configuration;
String instanceName;
public ConfigurationData(ConfigurationSymbol configuration, String instanceName) {
this.configuration = configuration;
this.instanceName = instanceName;
}
public ConfigurationSymbol getConfiguration() {
return configuration;
}
public String getInstanceName() {
return instanceName;
}
public String getNumEpoch() {
if (!getConfiguration().getEntryMap().containsKey("num_epoch")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("num_epoch").getValue());
}
public String getBatchSize() {
if (!getConfiguration().getEntryMap().containsKey("batch_size")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("batch_size") .getValue());
}
public Boolean getLoadCheckpoint() {
if (!getConfiguration().getEntryMap().containsKey("load_checkpoint")) {
return null;
}
return (Boolean) getConfiguration().getEntry("load_checkpoint").getValue().getValue();
}
public Boolean getNormalize() {
if (!getConfiguration().getEntryMap().containsKey("normalize")) {
return null;
}
return (Boolean) getConfiguration().getEntry("normalize").getValue().getValue();
}
public String getContext() {
if (!getConfiguration().getEntryMap().containsKey("context")) {
return null;
}
return getConfiguration().getEntry("context").getValue().toString();
}
public String getEvalMetric() {
if (!getConfiguration().getEntryMap().containsKey("eval_metric")) {
return null;
}
return getConfiguration().getEntry("eval_metric").getValue().toString();
}
public String getOptimizerName() {
if (getConfiguration().getOptimizer() == null) {
return null;
}
return getConfiguration().getOptimizer().getName();
}
public Map<String, String> getOptimizerParams() {
// get classes for single enum values
List<Class> lrPolicyClasses = new ArrayList<>();
for (LRPolicy enum_value: LRPolicy.values()) {
lrPolicyClasses.add(enum_value.getClass());
}
Map<String, String> mapToStrings = new HashMap<>();
Map<String, OptimizerParamSymbol> optimizerParams = getConfiguration().getOptimizer().getOptimizerParamMap();
for (Map.Entry<String, OptimizerParamSymbol> entry : optimizerParams.entrySet()) {
String paramName = entry.getKey();
String valueAsString = entry.getValue().toString();
Class realClass = entry.getValue().getValue().getValue().getClass();
if (realClass == Boolean.class) {
valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False";
}
else if (lrPolicyClasses.contains(realClass)) {
valueAsString = "'" + valueAsString + "'";
}
mapToStrings.put(paramName, valueAsString);
}
return mapToStrings;
}
}
......@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Convolution;
......
......@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
//can be removed
public enum Target {
......@@ -31,26 +31,7 @@ public enum Target {
CPP{
@Override
public String toString() {
return ".cpp";
return ".h";
}
};
public static Target fromString(String target){
switch (target.toLowerCase()){
case "python":
return PYTHON;
case "py":
return PYTHON;
case "cpp":
return CPP;
case "c++":
return CPP;
default:
throw new IllegalArgumentException();
}
}
}
......@@ -18,11 +18,19 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.se_rwth.commons.logging.Log;
import freemarker.template.Configuration;
import freemarker.template.Template;
import freemarker.template.TemplateException;
import freemarker.template.TemplateExceptionHandler;
import java.io.IOException;
import java.io.StringWriter;
import java.io.Writer;
import java.util.Map;
public class TemplateConfiguration {
private static TemplateConfiguration instance;
......@@ -30,7 +38,7 @@ public class TemplateConfiguration {
private TemplateConfiguration() {
configuration = new Configuration(Configuration.VERSION_2_3_23);
configuration.setClassForTemplateLoading(TemplateConfiguration.class, "/templates/");
configuration.setClassForTemplateLoading(TemplateConfiguration.class, "/templates/mxnet/");
configuration.setDefaultEncoding("UTF-8");
configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
}
......@@ -46,4 +54,25 @@ public class TemplateConfiguration {
return instance.getConfiguration();
}
public static void processTemplate(Map<String, Object> ftlContext, String templatePath, Writer writer){
try{
Template template = TemplateConfiguration.get().getTemplate(templatePath);
template.process(ftlContext, writer);
}
catch (IOException e) {
Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
catch (TemplateException e){
Log.error("An exception occured in template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
}
public static String processTemplate(Map<String, Object> ftlContext, String templatePath){
StringWriter writer = new StringWriter();
processTemplate(ftlContext, templatePath, writer);
return writer.toString();
}
}
......@@ -105,7 +105,7 @@ class ${tc.fileNameWithoutEnding}:
sys.exit(1)
def train(self, batch_size,
def train(self, batch_size=64,
num_epoch=10,
eval_metric='acc',
optimizer='adam',
......
import logging
import mxnet as mx
<#list configurations as config>
import CNNCreator_${config.instanceName}
</#list>
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
logger.addHandler(handler)
<#list configurations as config>
${config.instanceName} = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
${config.instanceName}.train(
<#if (config.batchSize)??>
batch_size=${config.batchSize},
</#if>
<#if (config.numEpoch)??>
num_epoch=${config.numEpoch},
</#if>
<#if (config.loadCheckpoint)??>
load_checkpoint=${config.loadCheckpoint?string("True","False")},
</#if>
<#if (config.context)??>
context='${config.context}',
</#if>
<#if (config.normalize)??>
normalize=${config.normalize?string("True","False")},
</#if>
<#if (config.evalMetric)??>
eval_metric='${config.evalMetric}',
</#if>
<#if (config.configuration.optimizer)??>
optimizer='${config.optimizerName}',
optimizer_params={
<#list config.optimizerParams?keys as param>
'${param}': ${config.optimizerParams[param]}<#sep>,
</#list>
}
</#if>
)
</#list>
\ No newline at end of file
${element.name} = mx.symbol.Pooling(data=${element.inputs[0]},
global_pool=True,
kernel=(1,1),
pool_type=${element.poolType},
pool_type="${element.poolType}",
name="${element.name}")
<#include "OutputShape.ftl">
\ No newline at end of file
......@@ -8,7 +8,7 @@
</#if>
${element.name} = mx.symbol.Pooling(data=${input},
kernel=(${tc.join(element.kernel, ",")}),
pool_type=${element.poolType},
pool_type="${element.poolType}",
stride=(${tc.join(element.stride, ",")}),
name="${element.name}")
<#include "OutputShape.ftl">
\ No newline at end of file
......@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch;
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.ModelingLanguageFamily;
import de.monticore.io.paths.ModelPath;
......
......@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch;
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._parser.CNNArchParser;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
......
......@@ -28,4 +28,4 @@ architecture VGG16(img_height=224, img_width=224, img_channels=3, classes=1000){
FullyConnected(units=classes) ->
Softmax() ->
predictions
}
}
\ No newline at end of file
......@@ -105,7 +105,7 @@ class CNNCreator_Alexnet:
sys.exit(1)