Added class CNNTrain2Caffe2 impelemnting corresponding interface

parent 84d1feea
Pipeline #72613 passed with stages
in 4 minutes and 19 seconds
......@@ -8,15 +8,15 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-caffe2-generator</artifactId>
<version>0.2.4-SNAPSHOT</version>
<version>0.2.5-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.2.5-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.2.4-SNAPSHOT</CNNTrain.version>
<CNNArch.version>0.2.6-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.2.5-SNAPSHOT</CNNTrain.version>
<embedded-montiarc-math-generator>0.0.25-SNAPSHOT</embedded-montiarc-math-generator>
<!-- .. Libraries .................................................. -->
......
......@@ -26,19 +26,19 @@ import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cmake.CMakeConfig;
import de.monticore.lang.monticar.generator.cmake.CMakeFindModule;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Path;
import java.util.*;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
public class CNNArch2Caffe2 implements CNNArchGenerator{
......@@ -87,24 +87,6 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
}
}
@Override
public Map<String, String> generateTrainer(List<ConfigurationSymbol> configurations, List<String> instanceNames, String mainComponentName) {
int numberOfNetworks = configurations.size();
if (configurations.size() != instanceNames.size()){
throw new IllegalStateException(
"The number of configurations and the number of instances for generation of the CNNTrainer is not equal. " +
"This should have been checked previously.");
}
List<ConfigurationData> configDataList = new ArrayList<>();
for(int i = 0; i < numberOfNetworks; i++){
configDataList.add(new ConfigurationData(configurations.get(i), instanceNames.get(i)));
}
Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
return Collections.singletonMap(
"CNNTrainer_" + mainComponentName + ".py",
TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl"));
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public Map<String, String> generateStrings(ArchitectureSymbol architecture){
Map<String, String> fileContentMap = new HashMap<>();
......@@ -163,19 +145,10 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
}
private void generateFromFilecontentsMap(Map<String, String> fileContentMap) throws IOException {
GeneratorCPP genCPP = new GeneratorCPP();
genCPP.setGenerationTargetPath(getGenerationTargetPath());
for (String fileName : fileContentMap.keySet()){
File f = new File(getGenerationTargetPath() + fileName);
Log.info(f.getName(), "FileCreation:");
if (!f.exists()) {
f.getParentFile().mkdirs();
if (!f.createNewFile()) {
Log.error("File could not be created");
}
}
FileWriter writer = new FileWriter(f);
writer.write(fileContentMap.get(fileName));
writer.close();
genCPP.generateFile(new FileContent(fileContentMap.get(fileName), fileName));
}
}
......@@ -186,7 +159,7 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
CMakeConfig cMakeConfig = new CMakeConfig(rootModelName);
cMakeConfig.addModuleDependency(new CMakeFindModule("Armadillo", true));
cMakeConfig.addCMakeCommandEnd("set(LIBS ${LIBS} mxnet)");
cMakeConfig.addCMakeCommand("set(LIBS ${LIBS} mxnet)");
Map<String,String> fileContentMap = new HashMap<>();
for (FileContent fileContent : cMakeConfig.generateCMakeFiles()){
......
package de.monticore.lang.monticar.cnnarch.caffe2generator;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.symboltable.GlobalScope;
import de.se_rwth.commons.logging.Log;
import java.io.IOException;
import java.nio.file.Path;
import java.util.*;
public class CNNTrain2Caffe2 implements CNNTrainGenerator {
private String generationTargetPath;
private String instanceName;
public CNNTrain2Caffe2() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
public String getInstanceName() {
String parsedInstanceName = this.instanceName.replace('.', '_').replace('[', '_').replace(']', '_');
parsedInstanceName = parsedInstanceName.substring(0, 1).toLowerCase() + parsedInstanceName.substring(1);
return parsedInstanceName;
}
public void setInstanceName(String instanceName) {
this.instanceName = instanceName;
}
public String getGenerationTargetPath() {
if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') {
this.generationTargetPath = generationTargetPath + "/";
}
return generationTargetPath;
}
public void setGenerationTargetPath(String generationTargetPath) {
this.generationTargetPath = generationTargetPath;
}
public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) {
final ModelPath mp = new ModelPath(modelsDirPath);
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
Optional<CNNTrainCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNTrainCompilationUnitSymbol.KIND);
if (!compilationUnit.isPresent()) {
Log.error("could not resolve training configuration " + rootModelName);
System.exit(1);
}
setInstanceName(compilationUnit.get().getFullName());
CNNTrainCocos.checkAll(compilationUnit.get());
return compilationUnit.get().getConfiguration();
}
public void generate(Path modelsDirPath, String rootModelName) {
ConfigurationSymbol configuration = getConfigurationSymbol(modelsDirPath, rootModelName);
Map<String, String> fileContents = generateStrings(configuration);
GeneratorCPP genCPP = new GeneratorCPP();
genCPP.setGenerationTargetPath(getGenerationTargetPath());
try {
for (String fileName : fileContents.keySet()){
genCPP.generateFile(new FileContent(fileContents.get(fileName), fileName));
}
} catch (IOException e) {
e.printStackTrace();
}
}
public Map<String, String> generateStrings(ConfigurationSymbol configuration) {
ConfigurationData configData = new ConfigurationData(configuration, getInstanceName());
List<ConfigurationData> configDataList = new ArrayList<>();
configDataList.add(configData);
Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
String templateContent = TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
return Collections.singletonMap("CNNTrainer_" + getInstanceName() + ".py", templateContent);
}
}
......@@ -7,37 +7,37 @@ import CNNCreator_${config.instanceName}
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log","w", encoding=None, delay="true")
handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
logger.addHandler(handler)
<#list configurations as config>
${config.instanceName} = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
${config.instanceName}.train(
<#if (config.batchSize)??>
batch_size = ${config.batchSize},
batch_size=${config.batchSize},
</#if>
<#if (config.numEpoch)??>
num_epoch = ${config.numEpoch},
num_epoch=${config.numEpoch},
</#if>
<#if (config.loadCheckpoint)??>
load_checkpoint = ${config.loadCheckpoint?string("True","False")},
load_checkpoint=${config.loadCheckpoint?string("True","False")},
</#if>
<#if (config.context)??>
context = '${config.context}',
context='${config.context}',
</#if>
<#if (config.normalize)??>
normalize = ${config.normalize?string("True","False")},
normalize=${config.normalize?string("True","False")},
</#if>
<#if (config.evalMetric)??>
eval_metric = '${config.evalMetric}',
eval_metric='${config.evalMetric}',
</#if>
<#if (config.configuration.optimizer)??>
optimizer = '${config.optimizerName}',
optimizer_params = {
optimizer='${config.optimizerName}',
optimizer_params={
<#list config.optimizerParams?keys as param>
'${param}': ${config.optimizerParams[param]}<#sep>,
</#list>
}
}
</#if>
)
</#list>
\ No newline at end of file
......@@ -60,15 +60,6 @@ public class AbstractSymtabTest {
return scope;
}
/* protected static ASTCNNArchCompilationUnit getAstNode(String modelPath, String model) {
Scope symTab = createSymTab(MODEL_PATH + modelPath);
CNNArchCompilationUnitSymbol comp = symTab.<CNNArchCompilationUnitSymbol> resolve(
model, CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull("Could not resolve model " + model, comp);
return (ASTCNNArchCompilationUnit) comp.getAstNode().get();
}*/
protected static CNNArchCompilationUnitSymbol getCompilationUnitSymbol(String modelPath, String model) {
Scope symTab = createSymTab(MODEL_PATH + modelPath);
CNNArchCompilationUnitSymbol comp = symTab.<CNNArchCompilationUnitSymbol> resolve(
......
......@@ -20,19 +20,13 @@
*/
package de.monticore.lang.monticar.cnnarch.caffe2generator;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.symboltable.GlobalScope;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.junit.Before;
import org.junit.Test;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
......@@ -121,143 +115,50 @@ public class GenerationTest extends AbstractSymtabTest{
assertTrue(Log.getFindings().size() == 3);
}
@Test
public void testCNNTrainerGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
List<ConfigurationSymbol> configurations = new ArrayList<>();
List<String> instanceNames = Arrays.asList("main_net1", "main_net2");
final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests"));
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
CNNTrainCompilationUnitSymbol compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("Network1", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("Network2", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
CNNArch2Caffe2 generator = new CNNArch2Caffe2();
Map<String,String> trainerMap = generator.generateTrainer(configurations, instanceNames, "main");
for (String fileName : trainerMap.keySet()){
FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName);
writer.write(trainerMap.get(fileName));
writer.close();
}
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNTrainer_main.py"));
}
@Test
public void testFullCfgGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
List<ConfigurationSymbol> configurations = new ArrayList<>();
List<String> instanceName = Arrays.asList("main_net1", "main_net2");
final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests"));
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
CNNTrainCompilationUnitSymbol compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("FullConfig", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("FullConfig2", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
CNNArch2Caffe2 generator = new CNNArch2Caffe2();
Map<String,String> trainerMap = generator.generateTrainer(configurations, instanceName, "mainFull");
for (String fileName : trainerMap.keySet()){
FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName);
writer.write(trainerMap.get(fileName));
writer.close();
}
String sourcePath = "src/test/resources/valid_tests";
CNNTrain2Caffe2 trainGenerator = new CNNTrain2Caffe2();
trainGenerator.generate(Paths.get(sourcePath), "FullConfig");
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNTrainer_mainFull.py"));
"CNNTrainer_fullConfig.py"));
}
@Test
public void testSimpleCfgGeneration() throws IOException {
Log.getFindings().clear();
List<ConfigurationSymbol> configurations = new ArrayList<>();
List<String> instanceName = Arrays.asList("main_net1", "main_net2");
final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests"));
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
CNNTrainCompilationUnitSymbol compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("SimpleConfig1", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
Path modelPath = Paths.get("src/test/resources/valid_tests");
CNNTrain2Caffe2 trainGenerator = new CNNTrain2Caffe2();
compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("SimpleConfig2", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
CNNArch2Caffe2 generator = new CNNArch2Caffe2();
Map<String,String> trainerMap = generator.generateTrainer(configurations, instanceName, "mainSimple");
for (String fileName : trainerMap.keySet()){
FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName);
writer.write(trainerMap.get(fileName));
writer.close();
}
trainGenerator.generate(modelPath, "SimpleConfig");
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNTrainer_mainSimple.py"));
"CNNTrainer_simpleConfig.py"));
}
@Test
public void testEmptyCfgGeneration() throws IOException {
Log.getFindings().clear();
List<ConfigurationSymbol> configurations = new ArrayList<>();
List<String> instanceName = Arrays.asList("main_net1");
final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests"));
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
CNNTrainCompilationUnitSymbol compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("EmptyConfig", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
CNNArch2Caffe2 generator = new CNNArch2Caffe2();
Map<String,String> trainerMap = generator.generateTrainer(configurations, instanceName, "mainEmpty");
for (String fileName : trainerMap.keySet()){
FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName);
writer.write(trainerMap.get(fileName));
writer.close();
}
Path modelPath = Paths.get("src/test/resources/valid_tests");
CNNTrain2Caffe2 trainGenerator = new CNNTrain2Caffe2();
trainGenerator.generate(modelPath, "EmptyConfig");
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNTrainer_mainEmpty.py"));
"CNNTrainer_emptyConfig.py"));
}
......
......@@ -12,6 +12,7 @@ set(INCLUDE_DIRS ${INCLUDE_DIRS} ${Armadillo_INCLUDE_DIRS})
set(LIBS ${LIBS} ${Armadillo_LIBRARIES})
# additional commands
set(LIBS ${LIBS} mxnet)
# create static library
include_directories(${INCLUDE_DIRS})
......@@ -24,4 +25,3 @@ set_target_properties(alexnet PROPERTIES LINKER_LANGUAGE CXX)
export(TARGETS alexnet FILE alexnet.cmake)
# additional commands end
set(LIBS ${LIBS} mxnet)
import logging
import mxnet as mx
import CNNCreator_main_net1
import CNNCreator_emptyConfig
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log","w", encoding=None, delay="true")
handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
logger.addHandler(handler)
main_net1 = CNNCreator_main_net1.CNNCreator_main_net1()
main_net1.train(
emptyConfig = CNNCreator_emptyConfig.CNNCreator_emptyConfig()
emptyConfig.train(
)
import logging
import mxnet as mx
import CNNCreator_main_net1
import CNNCreator_main_net2
import CNNCreator_fullConfig
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log","w", encoding=None, delay="true")
handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
logger.addHandler(handler)
main_net1 = CNNCreator_main_net1.CNNCreator_main_net1()
main_net1.train(
batch_size = 100,
num_epoch = 5,
load_checkpoint = True,
context = 'gpu',
normalize = True,
eval_metric = 'mse',
optimizer = 'rmsprop',
optimizer_params = {
fullConfig = CNNCreator_fullConfig.CNNCreator_fullConfig()
fullConfig.train(
batch_size=100,
num_epoch=5,
load_checkpoint=True,
context='gpu',
normalize=True,
eval_metric='mse',
optimizer='rmsprop',
optimizer_params={
'weight_decay': 0.01,
'centered': True,
'gamma2': 0.9,
......@@ -31,27 +30,5 @@ if __name__ == "__main__":
'learning_rate_minimum': 1.0E-5,
'learning_rate_policy': 'step',
'learning_rate': 0.001,
'step_size': 1000 }
)
main_net2 = CNNCreator_main_net2.CNNCreator_main_net2()
main_net2.train(
batch_size = 100,
num_epoch = 10,
load_checkpoint = False,
context = 'gpu',
normalize = False,
eval_metric = 'topKAccuracy',
optimizer = 'adam',
optimizer_params = {
'epsilon': 1.0E-6,
'weight_decay': 0.01,
'rescale_grad': 1.1,
'beta1': 0.9,
'clip_gradient': 10.0,
'beta2': 0.9,
'learning_rate_minimum': 0.001,
'learning_rate_policy': 'exp',
'learning_rate': 0.001,
'learning_rate_decay': 0.9,
'step_size': 1000 }
'step_size': 1000}
)
import logging
import mxnet as mx
import CNNCreator_main_net1
import CNNCreator_main_net2
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log","w", encoding=None, delay="true")
logger.addHandler(handler)
main_net1 = CNNCreator_main_net1.CNNCreator_main_net1()
main_net1.train(
batch_size = 64,
num_epoch = 10,
load_checkpoint = False,
context = 'gpu',
normalize = True,
optimizer = 'adam',
optimizer_params = {
'weight_decay': 1.0E-4,
'learning_rate': 0.01,
'learning_rate_decay': 0.8,
'step_size': 1000 }
)
main_net2 = CNNCreator_main_net2.CNNCreator_main_net2()
main_net2.train(
batch_size = 32,
num_epoch = 10,
load_checkpoint = False,
context = 'gpu',
normalize = True,
optimizer = 'adam',
optimizer_params = {
'weight_decay': 1.0E-4,
'learning_rate': 0.01,
'learning_rate_decay': 0.8,
'step_size': 1000 }
)
import logging
import mxnet as mx
import CNNCreator_main_net1
import CNNCreator_main_net2
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log","w", encoding=None, delay="true")
logger.addHandler(handler)
main_net1 = CNNCreator_main_net1.CNNCreator_main_net1()
main_net1.train(
batch_size = 100,
num_epoch = 50,
optimizer = 'adam',
optimizer_params = {
'learning_rate': 0.001 }
)
main_net2 = CNNCreator_main_net2.CNNCreator_main_net2()
main_net2.train(
batch_size = 100,
num_epoch = 5,
optimizer = 'sgd',
optimizer_params = {
'learning_rate': 0.1 }
)
import logging
import mxnet as mx
import CNNCreator_simpleConfig
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
logger.addHandler(handler)
simpleConfig = CNNCreator_simpleConfig.CNNCreator_simpleConfig()
simpleConfig.train(
batch_size=100,
num_epoch=50,
optimizer='adam',
optimizer_params={
'learning_rate': 0.001}
)
configuration FullConfig2{
num_epoch : 10
batch_size : 100
load_checkpoint : false
context : gpu
eval_metric : top_k_accuracy
normalize : false
optimizer : adam{
learning_rate : 0.001
learning_rate_minimum : 0.001
weight_decay : 0.01
learning_rate_decay : 0.9
learning_rate_policy : exp
step_size : 1000
rescale_grad : 1.1
clip_gradient : 10