Commit 9e00e9b6 authored by Svetlana Pavlitskaya's avatar Svetlana Pavlitskaya Committed by Evgeny Kusmenko

Using updated CNNArchMXNet, refactoring of code related to CNNTrain

parent 7111056a
......@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-emadl-generator</artifactId>
<version>0.2.3-SNAPSHOT</version>
<version>0.2.4-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -16,8 +16,8 @@
<!-- .. SE-Libraries .................................................. -->
<emadl.version>0.2.2-SNAPSHOT</emadl.version>
<CNNTrain.version>0.2.4-SNAPSHOT</CNNTrain.version>
<cnnarch-mxnet-generator.version>0.2.4-SNAPSHOT</cnnarch-mxnet-generator.version>
<CNNTrain.version>0.2.5-SNAPSHOT</CNNTrain.version>
<cnnarch-mxnet-generator.version>0.2.5-SNAPSHOT</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>0.2.2-SNAPSHOT</cnnarch-caffe2-generator.version>
<embedded-montiarc-math-generator>0.0.25-SNAPSHOT</embedded-montiarc-math-generator>
......
......@@ -4,24 +4,35 @@ package de.monticore.lang.monticar.emadl.generator;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNet;
import de.monticore.lang.monticar.cnnarch.caffe2generator.CNNArch2Caffe2;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet;
import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator;
import java.util.Optional;
public enum Backend {
MXNET{
@Override
public CNNArchGenerator getGenerator() {
public CNNArchGenerator getCNNArchGenerator() {
return new CNNArch2MxNet();
}
@Override
public CNNTrainGenerator getCNNTrainGenerator() {
return new CNNTrain2MxNet();
}
},
CAFFE2{
@Override
public CNNArchGenerator getGenerator() {
public CNNArchGenerator getCNNArchGenerator() {
return new CNNArch2Caffe2();
}
@Override
public CNNTrainGenerator getCNNTrainGenerator() {
return null;
} // not implemented yet
};
public abstract CNNArchGenerator getGenerator();
public abstract CNNArchGenerator getCNNArchGenerator();
public abstract CNNTrainGenerator getCNNTrainGenerator();
public static Optional<Backend> getBackendFromString(String backend){
switch (backend){
......
......@@ -20,19 +20,16 @@
*/
package de.monticore.lang.monticar.emadl.generator;
import com.google.common.base.Charsets;
import com.google.common.base.Joiner;
import com.google.common.base.Splitter;
import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ComponentSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ExpandedComponentInstanceSymbol;
import de.monticore.lang.math._symboltable.MathStatementsSymbol;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage;
import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.emadl._cocos.EMADLCocos;
import de.monticore.lang.monticar.generator.FileContent;
......@@ -42,13 +39,12 @@ import de.monticore.lang.monticar.generator.cpp.SimulatorIntegrationHelper;
import de.monticore.lang.monticar.generator.cpp.TypesGeneratorCPP;
import de.monticore.lang.monticar.generator.cpp.converter.TypeConverter;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.Splitters;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import java.io.*;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
......@@ -60,6 +56,7 @@ public class EMADLGenerator {
private GeneratorCPP emamGen;
private CNNArchGenerator cnnArchGenerator;
private CNNTrainGenerator cnnTrainGenerator;
private String modelsPath;
......@@ -68,7 +65,8 @@ public class EMADLGenerator {
emamGen = new GeneratorCPP();
emamGen.useArmadilloBackend();
emamGen.setGenerationTargetPath("./target/generated-sources-emadl/");
cnnArchGenerator = backend.getGenerator();
cnnArchGenerator = backend.getCNNArchGenerator();
cnnTrainGenerator = backend.getCNNTrainGenerator();
}
public String getModelsPath() {
......@@ -265,70 +263,50 @@ public class EMADLGenerator {
}
public List<FileContent> generateCNNTrainer(Set<ExpandedComponentInstanceSymbol> allInstances, String mainComponentName) {
List<String> cnnInstanceNames = new ArrayList<>();
List<ConfigurationSymbol> configurations = new ArrayList<>();
List<FileContent> fileContents = new ArrayList<>();
for (ExpandedComponentInstanceSymbol componentInstance : allInstances) {
ComponentSymbol component = componentInstance.getComponentType().getReferencedSymbol();
Optional<ArchitectureSymbol> architecture = component.getSpannedScope().resolve("", ArchitectureSymbol.KIND);
if (architecture.isPresent()) {
ConfigurationSymbol configuration = getTrainingConfiguration(mainComponentName, component, componentInstance);
configurations.add(configuration);
cnnInstanceNames.add(componentInstance.getFullName().replaceAll("\\.", "_"));
String trainConfigFilename;
String mainComponentConfigFilename = mainComponentName.replaceAll("\\.", "/");
String componentConfigFilename = component.getFullName().replaceAll("\\.", "/");
String instanceConfigFilename = component.getFullName().replaceAll("\\.", "/") + "_" + component.getName();
if (Files.exists(Paths.get( getModelsPath() + instanceConfigFilename + ".cnnt"))) {
trainConfigFilename = instanceConfigFilename;
}
else if (Files.exists(Paths.get( getModelsPath() + componentConfigFilename + ".cnnt"))){
trainConfigFilename = componentConfigFilename;
}
else if (Files.exists(Paths.get( getModelsPath() + mainComponentConfigFilename + ".cnnt"))){
trainConfigFilename = mainComponentConfigFilename;
}
else{
Log.error("Missing configuration file. " +
"Could not find a file with any of the following names (only one needed): '"
+ getModelsPath() + instanceConfigFilename + ".cnnt', '"
+ getModelsPath() + componentConfigFilename + ".cnnt', '"
+ getModelsPath() + mainComponentConfigFilename + ".cnnt'." +
" These files denote respectively the configuration for the single instance, the component or the whole system.");
return null;
}
//should be removed when CNNTrain supports packages
List<String> names = Splitter.on("/").splitToList(trainConfigFilename);
trainConfigFilename = names.get(names.size()-1);
Path modelPath = Paths.get(getModelsPath() + Joiner.on("/").join(names.subList(0,names.size()-1)));
ConfigurationSymbol configuration = cnnTrainGenerator.getConfigurationSymbol(modelPath, trainConfigFilename);
cnnTrainGenerator.setInstanceName(componentInstance.getFullName().replaceAll("\\.", "_"));
Map<String, String> fileContentMap = cnnTrainGenerator.generateStrings(configuration);
for (String fileName : fileContentMap.keySet()){
fileContents.add(new FileContent(fileContentMap.get(fileName), fileName));
}
}
}
List<FileContent> fileContents = new ArrayList<>();
Map<String, String> fileContentMap = cnnArchGenerator.generateTrainer(configurations, cnnInstanceNames, mainComponentName);
for (String fileName : fileContentMap.keySet()){
fileContents.add(new FileContent(fileContentMap.get(fileName), fileName));
}
return fileContents;
}
public ConfigurationSymbol getTrainingConfiguration(String mainComponentName, ComponentSymbol component, ExpandedComponentInstanceSymbol instance) {
String configFilename;
String mainComponentConfigFilename = mainComponentName.replaceAll("\\.", "/");
String componentConfigFilename = component.getFullName().replaceAll("\\.", "/");
String instanceConfigFilename = component.getFullName().replaceAll("\\.", "/") + "_" + instance.getName();
if (Files.exists(Paths.get( getModelsPath() + instanceConfigFilename + ".cnnt"))) {
configFilename = instanceConfigFilename;
}
else if (Files.exists(Paths.get( getModelsPath() + componentConfigFilename + ".cnnt"))){
configFilename = componentConfigFilename;
}
else if (Files.exists(Paths.get( getModelsPath() + mainComponentConfigFilename + ".cnnt"))){
configFilename = mainComponentConfigFilename;
}
else{
Log.error("Missing configuration file. " +
"Could not find a file with any of the following names (only one needed): '"
+ getModelsPath() + instanceConfigFilename + ".cnnt', '"
+ getModelsPath() + componentConfigFilename + ".cnnt', '"
+ getModelsPath() + mainComponentConfigFilename + ".cnnt'." +
" These files denote respectively the configuration for the single instance, the component or the whole system.");
return null;
}
//should be removed when CNNTrain supports packages
List<String> names = Splitter.on("/").splitToList(configFilename);
configFilename = names.get(names.size()-1);
Path modelPath = Paths.get(getModelsPath() + Joiner.on("/").join(names.subList(0,names.size()-1)));
//
//CNNTrainGenerator cnnTrainGenerator = new CNNTrainGenerator(); //No need of cnnTrainGenerator since cnnArchGenerator can also generateTrainer()
final ModelPath mp = new ModelPath(modelPath);
GlobalScope trainScope = new GlobalScope(mp, new CNNTrainLanguage());
Optional<CNNTrainCompilationUnitSymbol> compilationUnit = trainScope.resolve(configFilename, CNNTrainCompilationUnitSymbol.KIND);
if (!compilationUnit.isPresent()){
Log.error("CNNTrainCompilationUnitSymbol is empty. Could not resolve configuration " + configFilename);
System.exit(1);
}
CNNTrainCocos.checkAll(compilationUnit.get());
ConfigurationSymbol configuration = compilationUnit.get().getConfiguration();
return configuration;
}
public String readResource(final String fileName, Charset charset) {
try {
return Resources.toString(Resources.getResource(fileName), charset);
......
......@@ -59,7 +59,7 @@ public class GenerationTest extends AbstractSymtabTest {
"cifar10_cifar10Classifier_net.h",
"CNNTranslator.h",
"cifar10_cifar10Classifier_calculateClass.h",
"CNNTrainer_cifar10_Cifar10Classifier.py"));
"CNNTrainer_cifar10_cifar10Classifier_net.py"));
}
@Test
......
configuration NetworkB_net1{
configuration NetworkB{
num_epoch:10
batch_size:64
normalize:true
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment