Added generation of CNNTrainer.

parent 29d922ab
Pipeline #55976 failed with stages in 13 seconds
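For orientation, a minimal sketch of how the new three-argument generateTrainer API is consumed, mirroring the testCNNTrainerGeneration test added below; the configurations list and the instance/component names ("main_net1", "main_net2", "main") are illustrative placeholders only:

CNNArch2MxNet generator = new CNNArch2MxNet();
// one instance name per CNNTrain configuration; "main" becomes the file name suffix
Map<String, String> trainerFiles = generator.generateTrainer(
        configurations, Arrays.asList("main_net1", "main_net2"), "main");
// the returned map holds exactly one entry: "CNNTrainer_main.py" -> generated Python source
for (Map.Entry<String, String> file : trainerFiles.entrySet()) {
    try (FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + file.getKey())) {
        writer.write(file.getValue());
    }
}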
@@ -35,10 +35,7 @@ import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.*;
public class CNNArch2MxNet implements CNNArchGenerator {
@@ -83,8 +80,21 @@ public class CNNArch2MxNet implements CNNArchGenerator {
}
@Override
public Map<String, String> generateTrainer(List<ConfigurationSymbol> configurations, List<String> instanceNames) {
return null;
public Map<String, String> generateTrainer(List<ConfigurationSymbol> configurations, List<String> instanceNames, String mainComponentName) {
int numberOfNetworks = configurations.size();
if (configurations.size() != instanceNames.size()){
throw new IllegalStateException(
"The number of configurations and the number of instances for generation of the CNNTrainer is not equal. " +
"This should have been checked previously.");
}
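// pair each training configuration with the instance name of the network it trains; these pairs are handed to the template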
List<ConfigurationData> configDataList = new ArrayList<>();
for(int i = 0; i < numberOfNetworks; i++){
configDataList.add(new ConfigurationData(configurations.get(i), instanceNames.get(i)));
}
Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
return Collections.singletonMap(
"CNNTrainer_" + mainComponentName + ".py",
TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl"));
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
@@ -112,15 +122,13 @@ public class CNNArch2MxNet implements CNNArchGenerator {
private void checkValidGeneration(ArchitectureSymbol architecture){
if (architecture.getInputs().size() > 1){
Log.warn("This cnn architecture has multiple inputs, " +
"which is currently not supported by the generator. " +
"The generated code will not work correctly."
Log.error("This cnn architecture has multiple inputs, " +
"which is currently not supported by the generator. "
, architecture.getSourcePosition());
}
if (architecture.getOutputs().size() > 1){
Log.warn("This cnn architecture has multiple outputs, " +
"which is currently not supported by the generator. " +
"The generated code will not work correctly."
Log.error("This cnn architecture has multiple outputs, " +
"which is currently not supported by the generator. "
, architecture.getSourcePosition());
}
if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 ||
@@ -151,5 +159,4 @@ public class CNNArch2MxNet implements CNNArchGenerator {
writer.close();
}
}
}
@@ -23,12 +23,7 @@ package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid;
import de.monticore.lang.monticar.cnnarch.predefined.Softmax;
import de.se_rwth.commons.logging.Log;
import freemarker.template.Configuration;
import freemarker.template.Template;
import freemarker.template.TemplateException;
import java.io.IOException;
import java.io.StringWriter;
import java.io.Writer;
import java.util.*;
@@ -41,14 +36,15 @@ public class CNNArchTemplateController {
public static final String ELEMENT_DATA_KEY = "element";
private LayerNameCreator nameManager;
private Configuration freemarkerConfig = TemplateConfiguration.get();
private ArchitectureSymbol architecture;
//temporary attributes. They are set after calling process()
private Writer writer;
private String mainTemplateNameWithoutEnding;
private Target targetLanguage;
private ArchitectureElementData dataElement;
public CNNArchTemplateController(ArchitectureSymbol architecture) {
setArchitecture(architecture);
}
@@ -57,14 +53,6 @@ public class CNNArchTemplateController {
return mainTemplateNameWithoutEnding + "_" + getFullArchitectureName();
}
public Target getTargetLanguage(){
return targetLanguage;
}
public void setTargetLanguage(Target targetLanguage) {
this.targetLanguage = targetLanguage;
}
public ArchitectureElementData getCurrentElement() {
return dataElement;
}
@@ -137,25 +125,10 @@ public class CNNArchTemplateController {
public void include(String relativePath, String templateWithoutFileEnding, Writer writer){
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
try {
Template template = freemarkerConfig.getTemplate(templatePath);
Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
this.writer = writer;
template.process(ftlContext, writer);
this.writer = null;
}
catch (IOException e) {
Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
catch (TemplateException e){
Log.error("An exception occured in template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
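// template lookup and error handling are now centralized in TemplateConfiguration.processTemplate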
Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
TemplateConfiguration.processTemplate(ftlContext, templatePath, writer);
}
public void include(IOSymbol ioElement, Writer writer){
@@ -229,18 +202,16 @@ public class CNNArchTemplateController {
StringWriter writer = new StringWriter();
this.mainTemplateNameWithoutEnding = templateNameWithoutEnding;
this.targetLanguage = targetLanguage;
include("", templateNameWithoutEnding, writer);
this.writer = writer;
include("", templateNameWithoutEnding, writer);
String fileEnding = targetLanguage.toString();
if (targetLanguage == Target.CPP){
fileEnding = ".h";
}
String fileName = getFileNameWithoutEnding() + fileEnding;
Map.Entry<String,String> fileContent = new AbstractMap.SimpleEntry<>(fileName, writer.toString());
this.mainTemplateNameWithoutEnding = null;
this.targetLanguage = null;
this.writer = null;
return fileContent;
}
......
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class ConfigurationData {
ConfigurationSymbol configuration;
String instanceName;
public ConfigurationData(ConfigurationSymbol configuration, String instanceName) {
this.configuration = configuration;
this.instanceName = instanceName;
}
public ConfigurationSymbol getConfiguration() {
return configuration;
}
public String getInstanceName() {
return instanceName;
}
public String getNumEpoch() {
return String.valueOf(getConfiguration().getNumEpoch().getValue());
}
public String getBatchSize() {
return String.valueOf(getConfiguration().getBatchSize().getValue());
}
public LoadCheckpointSymbol getLoadCheckpoint() {
return getConfiguration().getLoadCheckpoint();
}
public NormalizeSymbol getNormalize() {
return getConfiguration().getNormalize();
}
public TrainContextSymbol getContext() {
return getConfiguration().getTrainContext();
}
public String getOptimizerName() {
return getConfiguration().getOptimizer().getName();
}
public Map<String, String> getOptimizerParams() {
// get classes for single enum values
List<Class> lrPolicyClasses = new ArrayList<>();
for (LRPolicy enum_value: LRPolicy.values()) {
lrPolicyClasses.add(enum_value.getClass());
}
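// convert each optimizer parameter to a Python literal: booleans map to True/False, learning-rate-policy values are wrapped in quotes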
Map<String, String> mapToStrings = new HashMap<>();
Map<String, OptimizerParamSymbol> optimizerParams = getConfiguration().getOptimizer().getOptimizerParamMap();
for (Map.Entry<String, OptimizerParamSymbol> entry : optimizerParams.entrySet()) {
String paramName = entry.getKey();
String valueAsString = entry.getValue().toString();
Class realClass = entry.getValue().getValue().getValue().getClass();
if (realClass == Boolean.class) {
valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False";
}
else if (lrPolicyClasses.contains(realClass)) {
valueAsString = "'" + valueAsString + "'";
}
mapToStrings.put(paramName, valueAsString);
}
return mapToStrings;
}
}
@@ -31,26 +31,7 @@ public enum Target {
CPP{
@Override
public String toString() {
return ".cpp";
return ".h";
}
};
public static Target fromString(String target){
switch (target.toLowerCase()){
case "python":
return PYTHON;
case "py":
return PYTHON;
case "cpp":
return CPP;
case "c++":
return CPP;
default:
throw new IllegalArgumentException();
}
}
}
@@ -20,9 +20,17 @@
*/
package de.monticore.lang.monticar.cnnarch.generator;
import de.se_rwth.commons.logging.Log;
import freemarker.template.Configuration;
import freemarker.template.Template;
import freemarker.template.TemplateException;
import freemarker.template.TemplateExceptionHandler;
import java.io.IOException;
import java.io.StringWriter;
import java.io.Writer;
import java.util.Map;
public class TemplateConfiguration {
private static TemplateConfiguration instance;
@@ -46,4 +54,25 @@ public class TemplateConfiguration {
return instance.getConfiguration();
}
public static void processTemplate(Map<String, Object> ftlContext, String templatePath, Writer writer){
try{
Template template = TemplateConfiguration.get().getTemplate(templatePath);
template.process(ftlContext, writer);
}
catch (IOException e) {
Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
catch (TemplateException e){
Log.error("An exception occured in template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
}
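// convenience overload that renders the template into a String instead of a caller-supplied Writer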
public static String processTemplate(Map<String, Object> ftlContext, String templatePath){
StringWriter writer = new StringWriter();
processTemplate(ftlContext, templatePath, writer);
return writer.toString();
}
}
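<#-- CNNTrainer.ftl: emits an import, a CNNCreator instantiation, and a train(...) call for every configured network instance -->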
import logging
import mxnet as mx
<#list configurations as config>
import CNNCreator_${config.instanceName}
</#list>
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log","w", encoding=None, delay="true")
logger.addHandler(handler)
<#list configurations as config>
${config.instanceName} = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
${config.instanceName}.train(
<#if (config.batchSize)??>
batch_size = ${config.batchSize},
</#if>
<#if (config.loadCheckpoint)??>
load_checkpoint = ${config.loadCheckpoint.value?string("True","False")},
</#if>
<#if (config.context)??>
context = '${config.context.value}',
</#if>
<#if (config.normalize)??>
normalize = ${config.normalize.value?string("True","False")},
</#if>
<#if (config.configuration.optimizer)??>
optimizer = '${config.optimizerName}',
optimizer_params = {
<#list config.optimizerParams?keys as param>
'${param}': ${config.optimizerParams[param]}<#sep>,
</#list>
}
</#if>
)
</#list>
\ No newline at end of file
@@ -20,15 +20,23 @@
*/
package de.monticore.lang.monticar.cnnarch;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch.generator.CNNArch2MxNet;
import de.monticore.lang.monticar.cnnarch.generator.CNNArch2MxNetCli;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.symboltable.GlobalScope;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.junit.Before;
import org.junit.Test;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.*;
import static junit.framework.TestCase.assertTrue;
@@ -114,4 +122,40 @@ public class GenerationTest extends AbstractSymtabTest{
CNNArch2MxNetCli.main(args);
assertTrue(Log.getFindings().size() == 3);
}
@Test
public void testCNNTrainerGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
List<ConfigurationSymbol> configurations = new ArrayList<>();
List<String> instanceNames = Arrays.asList("main_net1", "main_net2");
final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests"));
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
CNNTrainCompilationUnitSymbol compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("Network1", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("Network2", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
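// generate CNNTrainer_main.py for the main component "main" and write it to the generation target path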
CNNArch2MxNet generator = new CNNArch2MxNet();
Map<String,String> trainerMap = generator.generateTrainer(configurations, instanceNames, "main");
for (String fileName : trainerMap.keySet()){
FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName);
writer.write(trainerMap.get(fileName));
writer.close();
}
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNTrainer_main.py"));
}
}
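# Expected target code src/test/resources/target_code/CNNTrainer_main.py, generated from the Network1 and Network2 configurations below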
import logging
import mxnet as mx
import CNNCreator_main_net1
import CNNCreator_main_net2
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log","w", encoding=None, delay="true")
logger.addHandler(handler)
main_net1 = CNNCreator_main_net1.CNNCreator_main_net1()
main_net1.train(
batch_size = 64,
load_checkpoint = False,
context = 'gpu',
normalize = True,
optimizer = 'adam',
optimizer_params = {
'weight_decay': 1.0E-4,
'learning_rate': 0.01,
'learning_rate_decay': 0.8,
'step_size': 1000 }
)
main_net2 = CNNCreator_main_net2.CNNCreator_main_net2()
main_net2.train(
batch_size = 32,
load_checkpoint = False,
context = 'gpu',
normalize = True,
optimizer = 'adam',
optimizer_params = {
'weight_decay': 1.0E-4,
'learning_rate': 0.01,
'learning_rate_decay': 0.8,
'step_size': 1000 }
)
configuration Network1{
num_epoch:10
batch_size:64
normalize:true
context:gpu
load_checkpoint:false
optimizer:adam{
learning_rate:0.01
learning_rate_decay:0.8
step_size:1000
weight_decay:0.0001
}
}
configuration Network2{
num_epoch:10
batch_size:32
normalize:true
context:gpu
load_checkpoint:false
optimizer:adam{
learning_rate:0.01
learning_rate_decay:0.8
step_size:1000
weight_decay:0.0001
}
}