Commit f9a53bef authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'move-to-gluon' into 'master'

Move to gluon

Closes #1

See merge request !1
parents 4afa996b 842fbaff
Pipeline #100989 passed with stages
in 4 minutes and 24 seconds
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
<!-- == PROJECT COORDINATES ============================================= --> <!-- == PROJECT COORDINATES ============================================= -->
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId> <artifactId>cnnarch-gluon-generator</artifactId>
<version>0.2.8</version> <version>0.1.0-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
...@@ -172,7 +172,7 @@ ...@@ -172,7 +172,7 @@
<configuration> <configuration>
<archive> <archive>
<manifest> <manifest>
<mainClass>de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNetCli</mainClass> <mainClass>de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2GluonCli</mainClass>
</manifest> </manifest>
</archive> </archive>
<descriptorRefs> <descriptorRefs>
...@@ -229,7 +229,8 @@ ...@@ -229,7 +229,8 @@
<maxmem>256m</maxmem> <maxmem>256m</maxmem>
<!-- aggregated reports for multi-module projects --> <!-- aggregated reports for multi-module projects -->
<aggregate>true</aggregate> <aggregate>true</aggregate>
</configuration> <check/>
</configuration>
</plugin> </plugin>
</plugins> </plugins>
</build> </build>
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>. * License along with this project. If not, see <http://www.gnu.org/licenses/>.
* ******************************************************************************* * *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnnarch.mxnetgenerator; package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>. * License along with this project. If not, see <http://www.gnu.org/licenses/>.
* ******************************************************************************* * *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnnarch.mxnetgenerator; package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.io.paths.ModelPath; import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator; import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
...@@ -40,11 +40,11 @@ import java.util.HashMap; ...@@ -40,11 +40,11 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
public class CNNArch2MxNet implements CNNArchGenerator { public class CNNArch2Gluon implements CNNArchGenerator {
private String generationTargetPath; private String generationTargetPath;
public CNNArch2MxNet() { public CNNArch2Gluon() {
setGenerationTargetPath("./target/generated-sources-cnnarch/"); setGenerationTargetPath("./target/generated-sources-cnnarch/");
} }
...@@ -96,6 +96,9 @@ public class CNNArch2MxNet implements CNNArchGenerator { ...@@ -96,6 +96,9 @@ public class CNNArch2MxNet implements CNNArchGenerator {
temp = archTc.process("CNNPredictor", Target.CPP); temp = archTc.process("CNNPredictor", Target.CPP);
fileContentMap.put(temp.getKey(), temp.getValue()); fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("CNNNet", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("CNNCreator", Target.PYTHON); temp = archTc.process("CNNCreator", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue()); fileContentMap.put(temp.getKey(), temp.getValue());
......
...@@ -18,14 +18,14 @@ ...@@ -18,14 +18,14 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>. * License along with this project. If not, see <http://www.gnu.org/licenses/>.
* ******************************************************************************* * *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnnarch.mxnetgenerator; package de.monticore.lang.monticar.cnnarch.gluongenerator;
import org.apache.commons.cli.*; import org.apache.commons.cli.*;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
public class CNNArch2MxNetCli { public class CNNArch2GluonCli {
public static final Option OPTION_MODELS_PATH = Option.builder("m") public static final Option OPTION_MODELS_PATH = Option.builder("m")
.longOpt("models-dir") .longOpt("models-dir")
...@@ -48,7 +48,7 @@ public class CNNArch2MxNetCli { ...@@ -48,7 +48,7 @@ public class CNNArch2MxNetCli {
.required(false) .required(false)
.build(); .build();
private CNNArch2MxNetCli() { private CNNArch2GluonCli() {
} }
public static void main(String[] args) { public static void main(String[] args) {
...@@ -84,7 +84,7 @@ public class CNNArch2MxNetCli { ...@@ -84,7 +84,7 @@ public class CNNArch2MxNetCli {
Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt())); Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()));
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt()); String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt()); String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
CNNArch2MxNet generator = new CNNArch2MxNet(); CNNArch2Gluon generator = new CNNArch2Gluon();
if (outputPath != null){ if (outputPath != null){
generator.setGenerationTargetPath(outputPath); generator.setGenerationTargetPath(outputPath);
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>. * License along with this project. If not, see <http://www.gnu.org/licenses/>.
* ******************************************************************************* * *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnnarch.mxnetgenerator; package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.*; import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid; import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid;
...@@ -34,6 +34,7 @@ public class CNNArchTemplateController { ...@@ -34,6 +34,7 @@ public class CNNArchTemplateController {
public static final String TEMPLATE_ELEMENTS_DIR_PATH = "elements/"; public static final String TEMPLATE_ELEMENTS_DIR_PATH = "elements/";
public static final String TEMPLATE_CONTROLLER_KEY = "tc"; public static final String TEMPLATE_CONTROLLER_KEY = "tc";
public static final String ELEMENT_DATA_KEY = "element"; public static final String ELEMENT_DATA_KEY = "element";
public static final String NET_DEFINITION_MODE_KEY = "definition_mode";
private LayerNameCreator nameManager; private LayerNameCreator nameManager;
private ArchitectureSymbol architecture; private ArchitectureSymbol architecture;
...@@ -123,34 +124,43 @@ public class CNNArchTemplateController { ...@@ -123,34 +124,43 @@ public class CNNArchTemplateController {
return list; return list;
} }
public void include(String relativePath, String templateWithoutFileEnding, Writer writer){ public void include(String relativePath, String templateWithoutFileEnding, Writer writer, NetDefinitionMode netDefinitionMode){
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING; String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
Map<String, Object> ftlContext = new HashMap<>(); Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this); ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement()); ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
ftlContext.put(NET_DEFINITION_MODE_KEY, netDefinitionMode);
TemplateConfiguration.processTemplate(ftlContext, templatePath, writer); TemplateConfiguration.processTemplate(ftlContext, templatePath, writer);
} }
public void include(IOSymbol ioElement, Writer writer){ public void include(String relativePath, String templateWithoutFileEnding, Writer writer) {
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
TemplateConfiguration.processTemplate(ftlContext, templatePath, writer);
}
public void include(IOSymbol ioElement, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement(); ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(ioElement); setCurrentElement(ioElement);
if (ioElement.isAtomic()){ if (ioElement.isAtomic()){
if (ioElement.isInput()){ if (ioElement.isInput()){
include(TEMPLATE_ELEMENTS_DIR_PATH, "Input", writer); include(TEMPLATE_ELEMENTS_DIR_PATH, "Input", writer, netDefinitionMode);
} }
else { else {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Output", writer); include(TEMPLATE_ELEMENTS_DIR_PATH, "Output", writer, netDefinitionMode);
} }
} }
else { else {
include(ioElement.getResolvedThis().get(), writer); include(ioElement.getResolvedThis().get(), writer, netDefinitionMode);
} }
setCurrentElement(previousElement); setCurrentElement(previousElement);
} }
public void include(LayerSymbol layer, Writer writer){ public void include(LayerSymbol layer, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement(); ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(layer); setCurrentElement(layer);
...@@ -158,44 +168,48 @@ public class CNNArchTemplateController { ...@@ -158,44 +168,48 @@ public class CNNArchTemplateController {
ArchitectureElementSymbol nextElement = layer.getOutputElement().get(); ArchitectureElementSymbol nextElement = layer.getOutputElement().get();
if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement)){ if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement)){
String templateName = layer.getDeclaration().getName(); String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer); include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode);
} }
} }
else { else {
include(layer.getResolvedThis().get(), writer); include(layer.getResolvedThis().get(), writer, netDefinitionMode);
} }
setCurrentElement(previousElement); setCurrentElement(previousElement);
} }
public void include(CompositeElementSymbol compositeElement, Writer writer){ public void include(CompositeElementSymbol compositeElement, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement(); ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(compositeElement); setCurrentElement(compositeElement);
for (ArchitectureElementSymbol element : compositeElement.getElements()){ for (ArchitectureElementSymbol element : compositeElement.getElements()){
include(element, writer); include(element, writer, netDefinitionMode);
} }
setCurrentElement(previousElement); setCurrentElement(previousElement);
} }
public void include(ArchitectureElementSymbol architectureElement, Writer writer){ public void include(ArchitectureElementSymbol architectureElement, Writer writer, NetDefinitionMode netDefinitionMode){
if (architectureElement instanceof CompositeElementSymbol){ if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, writer); include((CompositeElementSymbol) architectureElement, writer, netDefinitionMode);
} }
else if (architectureElement instanceof LayerSymbol){ else if (architectureElement instanceof LayerSymbol){
include((LayerSymbol) architectureElement, writer); include((LayerSymbol) architectureElement, writer, netDefinitionMode);
} }
else { else {
include((IOSymbol) architectureElement, writer); include((IOSymbol) architectureElement, writer, netDefinitionMode);
} }
} }
public void include(ArchitectureElementSymbol architectureElement){ public void include(ArchitectureElementSymbol architectureElementSymbol, String netDefinitionMode) {
include(architectureElementSymbol, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(ArchitectureElementSymbol architectureElement, NetDefinitionMode netDefinitionMode){
if (writer == null){ if (writer == null){
throw new IllegalStateException("missing writer"); throw new IllegalStateException("missing writer");
} }
include(architectureElement, writer); include(architectureElement, writer, netDefinitionMode);
} }
public Map.Entry<String,String> process(String templateNameWithoutEnding, Target targetLanguage){ public Map.Entry<String,String> process(String templateNameWithoutEnding, Target targetLanguage){
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator; package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.io.paths.ModelPath; import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator; import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator;
...@@ -17,7 +17,7 @@ import java.io.IOException; ...@@ -17,7 +17,7 @@ import java.io.IOException;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.*; import java.util.*;
public class CNNTrain2MxNet implements CNNTrainGenerator { public class CNNTrain2Gluon implements CNNTrainGenerator {
private String generationTargetPath; private String generationTargetPath;
private String instanceName; private String instanceName;
...@@ -58,7 +58,7 @@ public class CNNTrain2MxNet implements CNNTrainGenerator { ...@@ -58,7 +58,7 @@ public class CNNTrain2MxNet implements CNNTrainGenerator {
} }
} }
public CNNTrain2MxNet() { public CNNTrain2Gluon() {
setGenerationTargetPath("./target/generated-sources-cnnarch/"); setGenerationTargetPath("./target/generated-sources-cnnarch/");
} }
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator; package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnntrain._symboltable.*; import de.monticore.lang.monticar.cnntrain._symboltable.*;
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>. * License along with this project. If not, see <http://www.gnu.org/licenses/>.
* ******************************************************************************* * *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnnarch.mxnetgenerator; package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.*; import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Convolution; import de.monticore.lang.monticar.cnnarch.predefined.Convolution;
......
package de.monticore.lang.monticar.cnnarch.gluongenerator;
/**
*
*/
public enum NetDefinitionMode {
ARCHITECTURE_DEFINITION,
FORWARD_FUNCTION;
public static NetDefinitionMode fromString(final String netDefinitionMode) {
switch(netDefinitionMode) {
case "ARCHITECTURE_DEFINITION":
return ARCHITECTURE_DEFINITION;
case "FORWARD_FUNCTION":
return FORWARD_FUNCTION;
default:
throw new IllegalArgumentException("Unknown Net Definition Mode");
}
}
}
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>. * License along with this project. If not, see <http://www.gnu.org/licenses/>.
* ******************************************************************************* * *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnnarch.mxnetgenerator; package de.monticore.lang.monticar.cnnarch.gluongenerator;
//can be removed //can be removed
public enum Target { public enum Target {
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>. * License along with this project. If not, see <http://www.gnu.org/licenses/>.
* ******************************************************************************* * *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnnarch.mxnetgenerator; package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.se_rwth.commons.logging.Log; import de.se_rwth.commons.logging.Log;
import freemarker.template.Configuration; import freemarker.template.Configuration;
...@@ -38,7 +38,7 @@ public class TemplateConfiguration { ...@@ -38,7 +38,7 @@ public class TemplateConfiguration {
private TemplateConfiguration() { private TemplateConfiguration() {
configuration = new Configuration(Configuration.VERSION_2_3_23); configuration = new Configuration(Configuration.VERSION_2_3_23);
configuration.setClassForTemplateLoading(TemplateConfiguration.class, "/templates/mxnet/"); configuration.setClassForTemplateLoading(TemplateConfiguration.class, "/templates/gluon/");
configuration.setDefaultEncoding("UTF-8"); configuration.setDefaultEncoding("UTF-8");
configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER); configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
} }
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator; package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnntrain._ast.*; import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.lang.monticar.cnntrain._visitor.CNNTrainVisitor; import de.monticore.lang.monticar.cnntrain._visitor.CNNTrainVisitor;
......
...@@ -6,6 +6,9 @@ import shutil ...@@ -6,6 +6,9 @@ import shutil
import h5py import h5py
import sys import sys
import numpy as np import numpy as np
import time
from mxnet import gluon, autograd, nd
from CNNNet_${tc.fullArchitectureName} import Net
@mx.init.register @mx.init.register
class MyConstant(mx.init.Initializer): class MyConstant(mx.init.Initializer):
...@@ -17,7 +20,6 @@ class MyConstant(mx.init.Initializer): ...@@ -17,7 +20,6 @@ class MyConstant(mx.init.Initializer):
class ${tc.fileNameWithoutEnding}: class ${tc.fileNameWithoutEnding}:
module = None
_data_dir_ = "data/${tc.fullArchitectureName}/" _data_dir_ = "data/${tc.fullArchitectureName}/"
_model_dir_ = "model/${tc.fullArchitectureName}/" _model_dir_ = "model/${tc.fullArchitectureName}/"
_model_prefix_ = "${tc.architectureName}" _model_prefix_ = "${tc.architectureName}"
...@@ -25,6 +27,9 @@ class ${tc.fileNameWithoutEnding}: ...@@ -25,6 +27,9 @@ class ${tc.fileNameWithoutEnding}:
_input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")})</#list>] _input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")})</#list>]
_output_names_ = [${tc.join(tc.architectureOutputs, ",", "'", "_label'")}] _output_names_ = [${tc.join(tc.architectureOutputs, ",", "'", "_label'")}]
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.net = None
def load(self, context): def load(self, context):
lastEpoch = 0 lastEpoch = 0
...@@ -51,11 +56,7 @@ class ${tc.fileNameWithoutEnding}: ...@@ -51,11 +56,7 @@ class ${tc.fileNameWithoutEnding}:
return 0 return 0
else: else:
logging.info("Loading checkpoint: " + param_file) logging.info("Loading checkpoint: " + param_file)
self.module.load(prefix=self._model_dir_ + self._model_prefix_, self.net.load_parameters(param_file)
epoch=lastEpoch,
data_names=self._input_names_,
label_names=self._output_names_,
context=context)
return lastEpoch return lastEpoch
...@@ -138,11 +139,11 @@ class ${tc.fileNameWithoutEnding}: ...@@ -138,11 +139,11 @@ class ${tc.fileNameWithoutEnding}:
train_iter, test_iter, data_mean, data_std = self.load_data(batch_size) train_iter, test_iter, data_mean, data_std = self.load_data(batch_size)
if self.module == None: if self.net == None:
if normalize: if normalize:
self.construct(mx_context, data_mean, data_std) self.construct(context=mx_context, data_mean=nd.array(data_mean), data_std=nd.array(data_std))
else: else:
self.construct(mx_context) self.construct(context=mx_context)
begin_epoch = 0 begin_epoch = 0
if load_checkpoint: if load_checkpoint:
...@@ -157,23 +158,79 @@ class ${tc.fileNameWithoutEnding}: ...@@ -157,23 +158,79 @@ class ${tc.fileNameWithoutEnding}:
if not os.path.isdir(self._model_dir_): if not os.path.isdir(self._model_dir_):
raise raise
self.module.fit( trainer = mx.gluon.Trainer(self.net.collect_params(), optimizer, optimizer_params)
train_data=train_iter,
eval_metric=eval_metric, if self.net.last_layer == 'softmax':
eval_data=test_iter, loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss()
optimizer=optimizer,