Commit 1802f7d0 authored by Svetlana Pavlitskaya

A new generator class for the CNNTrain language implementing the corresponding interface

parent cfb99d88
Pipeline #71338 passed with stages in 4 minutes and 5 seconds
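For orientation, a minimal usage sketch of the new CNNTrain2MxNet generator, modeled on how the updated tests below invoke it; the model directory, configuration name, and wrapper class are illustrative only:

import java.nio.file.Path;
import java.nio.file.Paths;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet;

// Hypothetical wrapper class for illustration; only the CNNTrain2MxNet calls
// below are taken from this commit's code and tests.
public class TrainerGenerationExample {
    public static void main(String[] args) {
        // Directory containing the CNNTrain configuration models (illustrative path)
        Path modelsDir = Paths.get("src/test/resources/valid_tests");
        CNNTrain2MxNet trainGenerator = new CNNTrain2MxNet();
        // Defaults to ./target/generated-sources-cnnarch/; set explicitly here for clarity
        trainGenerator.setGenerationTargetPath("./target/generated-sources-cnnarch/");
        // Resolves the configuration, checks its context conditions and writes
        // CNNTrainer_<instanceName>.py to the generation target path
        trainGenerator.generate(modelsDir, "FullConfig");
    }
}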
@@ -8,15 +8,15 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId>
<version>0.2.4-SNAPSHOT</version>
<version>0.2.5-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.2.5-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.2.4-SNAPSHOT</CNNTrain.version>
<CNNArch.version>0.2.6-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.2.5-SNAPSHOT</CNNTrain.version>
<embedded-montiarc-math-generator>0.0.25-SNAPSHOT</embedded-montiarc-math-generator>
<!-- .. Libraries .................................................. -->
......
@@ -26,19 +26,19 @@ import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cmake.CMakeConfig;
import de.monticore.lang.monticar.generator.cmake.CMakeFindModule;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Path;
import java.util.*;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
public class CNNArch2MxNet implements CNNArchGenerator {
@@ -87,24 +87,6 @@ public class CNNArch2MxNet implements CNNArchGenerator {
}
}
@Override
public Map<String, String> generateTrainer(List<ConfigurationSymbol> configurations, List<String> instanceNames, String mainComponentName) {
int numberOfNetworks = configurations.size();
if (configurations.size() != instanceNames.size()){
throw new IllegalStateException(
"The number of configurations and the number of instances for generation of the CNNTrainer is not equal. " +
"This should have been checked previously.");
}
List<ConfigurationData> configDataList = new ArrayList<>();
for(int i = 0; i < numberOfNetworks; i++){
configDataList.add(new ConfigurationData(configurations.get(i), instanceNames.get(i)));
}
Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
return Collections.singletonMap(
"CNNTrainer_" + mainComponentName + ".py",
TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl"));
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public Map<String, String> generateStrings(ArchitectureSymbol architecture){
Map<String, String> fileContentMap = new HashMap<>();
@@ -163,19 +145,10 @@ public class CNNArch2MxNet implements CNNArchGenerator {
}
private void generateFromFilecontentsMap(Map<String, String> fileContentMap) throws IOException {
GeneratorCPP genCPP = new GeneratorCPP();
genCPP.setGenerationTargetPath(getGenerationTargetPath());
for (String fileName : fileContentMap.keySet()){
File f = new File(getGenerationTargetPath() + fileName);
Log.info(f.getName(), "FileCreation:");
if (!f.exists()) {
f.getParentFile().mkdirs();
if (!f.createNewFile()) {
Log.error("File could not be created");
}
}
FileWriter writer = new FileWriter(f);
writer.write(fileContentMap.get(fileName));
writer.close();
genCPP.generateFile(new FileContent(fileContentMap.get(fileName), fileName));
}
}
@@ -186,7 +159,7 @@ public class CNNArch2MxNet implements CNNArchGenerator {
CMakeConfig cMakeConfig = new CMakeConfig(rootModelName);
cMakeConfig.addModuleDependency(new CMakeFindModule("Armadillo", true));
cMakeConfig.addCMakeCommandEnd("set(LIBS ${LIBS} mxnet)");
cMakeConfig.addCMakeCommand("set(LIBS ${LIBS} mxnet)");
Map<String,String> fileContentMap = new HashMap<>();
for (FileContent fileContent : cMakeConfig.generateCMakeFiles()){
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.symboltable.GlobalScope;
import de.se_rwth.commons.logging.Log;
import java.io.IOException;
import java.nio.file.Path;
import java.util.*;
public class CNNTrain2MxNet implements CNNTrainGenerator {
private String generationTargetPath;
private String instanceName;
public CNNTrain2MxNet() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
public String getInstanceName() {
String parsedInstanceName = this.instanceName.replace('.', '_').replace('[', '_').replace(']', '_');
parsedInstanceName = parsedInstanceName.substring(0, 1).toLowerCase() + parsedInstanceName.substring(1);
return parsedInstanceName;
}
public void setInstanceName(String instanceName) {
this.instanceName = instanceName;
}
public String getGenerationTargetPath() {
if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') {
this.generationTargetPath = generationTargetPath + "/";
}
return generationTargetPath;
}
public void setGenerationTargetPath(String generationTargetPath) {
this.generationTargetPath = generationTargetPath;
}
public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) {
final ModelPath mp = new ModelPath(modelsDirPath);
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
Optional<CNNTrainCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNTrainCompilationUnitSymbol.KIND);
if (!compilationUnit.isPresent()) {
Log.error("could not resolve training configuration " + rootModelName);
System.exit(1);
}
setInstanceName(compilationUnit.get().getFullName());
CNNTrainCocos.checkAll(compilationUnit.get());
return compilationUnit.get().getConfiguration();
}
public void generate(Path modelsDirPath, String rootModelName) {
ConfigurationSymbol configuration = getConfigurationSymbol(modelsDirPath, rootModelName);
Map<String, String> fileContents = generateStrings(configuration);
GeneratorCPP genCPP = new GeneratorCPP();
genCPP.setGenerationTargetPath(getGenerationTargetPath());
try {
for (String fileName : fileContents.keySet()){
genCPP.generateFile(new FileContent(fileContents.get(fileName), fileName));
}
} catch (IOException e) {
e.printStackTrace();
}
}
public Map<String, String> generateStrings(ConfigurationSymbol configuration) {
ConfigurationData configData = new ConfigurationData(configuration, getInstanceName());
List<ConfigurationData> configDataList = new ArrayList<>();
configDataList.add(configData);
Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
String templateContent = TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
return Collections.singletonMap("CNNTrainer_" + getInstanceName() + ".py", templateContent);
}
}
\ No newline at end of file
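The CNNTrainGenerator interface referenced in the commit message is not part of this diff. For context, an assumed sketch of its shape, inferred solely from the public methods CNNTrain2MxNet implements above; the actual interface is defined in the CNNTrain project and may declare fewer or different methods:

import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import java.nio.file.Path;
import java.util.Map;

// Assumed shape only, inferred from CNNTrain2MxNet's public methods; not taken
// from this commit and possibly differing from the real CNNTrainGenerator interface.
public interface CNNTrainGenerator {
    ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName);
    void generate(Path modelsDirPath, String rootModelName);
    Map<String, String> generateStrings(ConfigurationSymbol configuration);
    String getGenerationTargetPath();
    void setGenerationTargetPath(String generationTargetPath);
    String getInstanceName();
    void setInstanceName(String instanceName);
}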
@@ -60,15 +60,6 @@ public class AbstractSymtabTest {
return scope;
}
/* protected static ASTCNNArchCompilationUnit getAstNode(String modelPath, String model) {
Scope symTab = createSymTab(MODEL_PATH + modelPath);
CNNArchCompilationUnitSymbol comp = symTab.<CNNArchCompilationUnitSymbol> resolve(
model, CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull("Could not resolve model " + model, comp);
return (ASTCNNArchCompilationUnit) comp.getAstNode().get();
}*/
protected static CNNArchCompilationUnitSymbol getCompilationUnitSymbol(String modelPath, String model) {
Scope symTab = createSymTab(MODEL_PATH + modelPath);
CNNArchCompilationUnitSymbol comp = symTab.<CNNArchCompilationUnitSymbol> resolve(
......
@@ -20,19 +20,13 @@
*/
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.symboltable.GlobalScope;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.junit.Before;
import org.junit.Test;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
@@ -121,143 +115,50 @@ public class GenerationTest extends AbstractSymtabTest{
assertTrue(Log.getFindings().size() == 3);
}
@Test
public void testCNNTrainerGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
List<ConfigurationSymbol> configurations = new ArrayList<>();
List<String> instanceNames = Arrays.asList("main_net1", "main_net2");
final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests"));
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
CNNTrainCompilationUnitSymbol compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("Network1", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("Network2", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
CNNArch2MxNet generator = new CNNArch2MxNet();
Map<String,String> trainerMap = generator.generateTrainer(configurations, instanceNames, "main");
for (String fileName : trainerMap.keySet()){
FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName);
writer.write(trainerMap.get(fileName));
writer.close();
}
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNTrainer_main.py"));
}
@Test
public void testFullCfgGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
List<ConfigurationSymbol> configurations = new ArrayList<>();
List<String> instanceName = Arrays.asList("main_net1", "main_net2");
final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests"));
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
CNNTrainCompilationUnitSymbol compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("FullConfig", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("FullConfig2", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
CNNArch2MxNet generator = new CNNArch2MxNet();
Map<String,String> trainerMap = generator.generateTrainer(configurations, instanceName, "mainFull");
for (String fileName : trainerMap.keySet()){
FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName);
writer.write(trainerMap.get(fileName));
writer.close();
}
String sourcePath = "src/test/resources/valid_tests";
CNNTrain2MxNet trainGenerator = new CNNTrain2MxNet();
trainGenerator.generate(Paths.get(sourcePath), "FullConfig");
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNTrainer_mainFull.py"));
"CNNTrainer_fullConfig.py"));
}
@Test
public void testSimpleCfgGeneration() throws IOException {
Log.getFindings().clear();
List<ConfigurationSymbol> configurations = new ArrayList<>();
List<String> instanceName = Arrays.asList("main_net1", "main_net2");
final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests"));
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
CNNTrainCompilationUnitSymbol compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("SimpleConfig1", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
Path modelPath = Paths.get("src/test/resources/valid_tests");
CNNTrain2MxNet trainGenerator = new CNNTrain2MxNet();
compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("SimpleConfig2", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
CNNArch2MxNet generator = new CNNArch2MxNet();
Map<String,String> trainerMap = generator.generateTrainer(configurations, instanceName, "mainSimple");
for (String fileName : trainerMap.keySet()){
FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName);
writer.write(trainerMap.get(fileName));
writer.close();
}
trainGenerator.generate(modelPath, "SimpleConfig");
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNTrainer_mainSimple.py"));
"CNNTrainer_simpleConfig.py"));
}
@Test
public void testEmptyCfgGeneration() throws IOException {
Log.getFindings().clear();
List<ConfigurationSymbol> configurations = new ArrayList<>();
List<String> instanceName = Arrays.asList("main_net1");
final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests"));
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
CNNTrainCompilationUnitSymbol compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("EmptyConfig", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
CNNArch2MxNet generator = new CNNArch2MxNet();
Map<String,String> trainerMap = generator.generateTrainer(configurations, instanceName, "mainEmpty");
for (String fileName : trainerMap.keySet()){
FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName);
writer.write(trainerMap.get(fileName));
writer.close();
}
Path modelPath = Paths.get("src/test/resources/valid_tests");
CNNTrain2MxNet trainGenerator = new CNNTrain2MxNet();
trainGenerator.generate(modelPath, "EmptyConfig");
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNTrainer_mainEmpty.py"));
"CNNTrainer_emptyConfig.py"));
}
......
@@ -12,6 +12,7 @@ set(INCLUDE_DIRS ${INCLUDE_DIRS} ${Armadillo_INCLUDE_DIRS})
set(LIBS ${LIBS} ${Armadillo_LIBRARIES})
# additional commands
set(LIBS ${LIBS} mxnet)
# create static library
include_directories(${INCLUDE_DIRS})
@@ -24,4 +25,3 @@ set_target_properties(alexnet PROPERTIES LINKER_LANGUAGE CXX)
export(TARGETS alexnet FILE alexnet.cmake)
# additional commands end
set(LIBS ${LIBS} mxnet)
import logging
import mxnet as mx
import CNNCreator_main_net1
import CNNCreator_emptyConfig
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
@@ -8,6 +8,6 @@ if __name__ == "__main__":
handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
logger.addHandler(handler)
main_net1 = CNNCreator_main_net1.CNNCreator_main_net1()
main_net1.train(
emptyConfig = CNNCreator_emptyConfig.CNNCreator_emptyConfig()
emptyConfig.train(
)
import logging
import mxnet as mx
import CNNCreator_main_net1
import CNNCreator_main_net2
import CNNCreator_fullConfig
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
@@ -9,8 +8,8 @@ if __name__ == "__main__":
handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
logger.addHandler(handler)
main_net1 = CNNCreator_main_net1.CNNCreator_main_net1()
main_net1.train(
fullConfig = CNNCreator_fullConfig.CNNCreator_fullConfig()
fullConfig.train(
batch_size=100,
num_epoch=5,
load_checkpoint=True,
@@ -33,25 +32,3 @@ if __name__ == "__main__":
'learning_rate': 0.001,
'step_size': 1000}
)
main_net2 = CNNCreator_main_net2.CNNCreator_main_net2()
main_net2.train(
batch_size=100,
num_epoch=10,
load_checkpoint=False,
context='gpu',
normalize=False,
eval_metric='topKAccuracy',
optimizer='adam',
optimizer_params={
'epsilon': 1.0E-6,
'weight_decay': 0.01,
'rescale_grad': 1.1,
'beta1': 0.9,
'clip_gradient': 10.0,
'beta2': 0.9,
'learning_rate_minimum': 0.001,
'learning_rate_policy': 'exp',
'learning_rate': 0.001,
'learning_rate_decay': 0.9,
'step_size': 1000}
)
import logging
import mxnet as mx
import CNNCreator_main_net1
import CNNCreator_main_net2
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
logger.addHandler(handler)
main_net1 = CNNCreator_main_net1.CNNCreator_main_net1()
main_net1.train(
batch_size=64,
num_epoch=10,
load_checkpoint=False,
context='gpu',
normalize=True,
optimizer='adam',
optimizer_params={
'weight_decay': 1.0E-4,
'learning_rate': 0.01,
'learning_rate_decay': 0.8,
'step_size': 1000}
)
main_net2 = CNNCreator_main_net2.CNNCreator_main_net2()
main_net2.train(
batch_size=32,
num_epoch=10,
load_checkpoint=False,
context='gpu',
normalize=True,
optimizer='adam',
optimizer_params={
'weight_decay': 1.0E-4,
'learning_rate': 0.01,
'learning_rate_decay': 0.8,
'step_size': 1000}
)
import logging
import mxnet as mx
import CNNCreator_main_net1
import CNNCreator_main_net2
import CNNCreator_simpleConfig
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
@@ -9,19 +8,11 @@ if __name__ == "__main__":
handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
logger.addHandler(handler)
main_net1 = CNNCreator_main_net1.CNNCreator_main_net1()
main_net1.train(
simpleConfig = CNNCreator_simpleConfig.CNNCreator_simpleConfig()
simpleConfig.train(
batch_size=100,
num_epoch=50,
optimizer='adam',
optimizer_params={
'learning_rate': 0.001}
)
main_net2 = CNNCreator_main_net2.CNNCreator_main_net2()
main_net2.train(
batch_size=100,
num_epoch=5,
optimizer='sgd',
optimizer_params={
'learning_rate': 0.1}
)
configuration FullConfig2{
num_epoch : 10
batch_size : 100
load_checkpoint : false
context : gpu
eval_metric : top_k_accuracy
normalize : false
optimizer : adam{
learning_rate : 0.001
learning_rate_minimum : 0.001
weight_decay : 0.01
learning_rate_decay : 0.9
learning_rate_policy : exp
step_size : 1000
rescale_grad : 1.1
clip_gradient : 10
beta1 : 0.9
beta2 : 0.9
epsilon : 0.000001
}
}
configuration Network1{
num_epoch:10
batch_size:64
normalize:true
context:gpu
load_checkpoint:false
optimizer:adam{
learning_rate:0.01
learning_rate_decay:0.8
step_size:1000
weight_decay:0.0001
}
}
configuration Network2{
num_epoch:10
batch_size:32
normalize:true
context:gpu
load_checkpoint:false
optimizer:adam{
learning_rate:0.01
learning_rate_decay:0.8
step_size:1000
weight_decay:0.0001
}
}
configuration SimpleConfig1{
configuration SimpleConfig{
num_epoch : 50
batch_size : 100
optimizer : adam{
......
configuration SimpleConfig2{
num_epoch:5
batch_size:100
optimizer:sgd{
learning_rate:0.1
}
}