diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java index f8e1804f5a899feaf5121b0559b0c2c7accd6b33..36ebd69ad692f562d1f5eecc602f65aae7bc5e0c 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java @@ -51,18 +51,18 @@ public class CNNArch2Gluon implements CNNArchGenerator { private boolean isSupportedLayer(ArchitectureElementSymbol element, LayerSupportChecker layerChecker){ List constructLayerElemList; - if (!(element instanceof IOSymbol) && (element.getResolvedThis().get() instanceof CompositeElementSymbol)) - { + if (!(element instanceof IOSymbol) && (element.getResolvedThis().get() instanceof CompositeElementSymbol)) { constructLayerElemList = ((CompositeElementSymbol)element.getResolvedThis().get()).getElements(); for (ArchitectureElementSymbol constructedLayerElement : constructLayerElemList) { - if (!isSupportedLayer(constructedLayerElement, layerChecker)) return false; + if (!isSupportedLayer(constructedLayerElement, layerChecker)) { + return false; + } } } if (!layerChecker.isSupported(element.toString())) { Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the backend MXNET."); return false; - } - else { + } else { return true; } } @@ -70,11 +70,18 @@ public class CNNArch2Gluon implements CNNArchGenerator { private boolean supportCheck(ArchitectureSymbol architecture){ LayerSupportChecker layerChecker = new LayerSupportChecker(); for (ArchitectureElementSymbol element : ((CompositeElementSymbol)architecture.getBody()).getElements()){ - if(!isSupportedLayer(element, layerChecker)) return false; + if(!isSupportedLayer(element, layerChecker)) { + return false; + } } return true; } + private static void quitGeneration(){ + Log.error("Code generation is aborted"); + System.exit(1); + } + public CNNArch2Gluon() { setGenerationTargetPath("./target/generated-sources-cnnarch/"); } @@ -105,19 +112,17 @@ public class CNNArch2Gluon implements CNNArchGenerator { Optional compilationUnit = scope.resolve(rootModelName, CNNArchCompilationUnitSymbol.KIND); if (!compilationUnit.isPresent()){ Log.error("could not resolve architecture " + rootModelName); - System.exit(1); + quitGeneration(); } CNNArchCocos.checkAll(compilationUnit.get()); if (!supportCheck(compilationUnit.get().getArchitecture())){ - Log.error("Code generation aborted."); - System.exit(1); + quitGeneration(); } try{ generateFiles(compilationUnit.get().getArchitecture()); - } - catch (IOException e){ + } catch (IOException e){ Log.error(e.toString()); } } @@ -178,7 +183,7 @@ public class CNNArch2Gluon implements CNNArchGenerator { try { generateFromFilecontentsMap(fileContentMap); } catch (IOException e) { - e.printStackTrace(); + Log.error("CMake file could not be generated" + e.getMessage()); } } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonCli.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonCli.java index cbe92a677c7a2e1311c72ab2a20af4eec13ae852..582053ff700ed5c6017e8ca0c1a8b0488353a6f7 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonCli.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonCli.java @@ -19,6 +19,7 @@ * ******************************************************************************* */ package de.monticore.lang.monticar.cnnarch.gluongenerator; +import de.se_rwth.commons.logging.Log; import org.apache.commons.cli.*; @@ -73,13 +74,18 @@ public class CNNArch2GluonCli { try { cliArgs = parser.parse(options, args); } catch (ParseException e) { - System.err.println("argument parsing exception: " + e.getMessage()); - System.exit(1); + Log.error("argument parsing exception: " + e.getMessage()); + quitGeneration(); return null; } return cliArgs; } + private static void quitGeneration(){ + Log.error("Code generation is aborted"); + System.exit(1); + } + private static void runGenerator(CommandLine cliArgs) { Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt())); String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt()); diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArchTemplateController.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArchTemplateController.java index 3f1716f05b9934493ef22595c68ec596b8018043..c5d7a5585c8358ba109fdc81590f9a499e140bf1 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArchTemplateController.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArchTemplateController.java @@ -92,8 +92,7 @@ public class CNNArchTemplateController { if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer)){ inputNames = getLayerInputs(layer.getInputElement().get()); - } - else { + } else { for (ArchitectureElementSymbol input : layer.getPrevious()) { if (input.getOutputTypes().size() == 1) { inputNames.add(getName(input)); @@ -148,12 +147,10 @@ public class CNNArchTemplateController { if (ioElement.isAtomic()){ if (ioElement.isInput()){ include(TEMPLATE_ELEMENTS_DIR_PATH, "Input", writer, netDefinitionMode); - } - else { + } else { include(TEMPLATE_ELEMENTS_DIR_PATH, "Output", writer, netDefinitionMode); } - } - else { + } else { include(ioElement.getResolvedThis().get(), writer, netDefinitionMode); } @@ -170,8 +167,7 @@ public class CNNArchTemplateController { String templateName = layer.getDeclaration().getName(); include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode); } - } - else { + } else { include(layer.getResolvedThis().get(), writer, netDefinitionMode); } @@ -192,11 +188,9 @@ public class CNNArchTemplateController { public void include(ArchitectureElementSymbol architectureElement, Writer writer, NetDefinitionMode netDefinitionMode){ if (architectureElement instanceof CompositeElementSymbol){ include((CompositeElementSymbol) architectureElement, writer, netDefinitionMode); - } - else if (architectureElement instanceof LayerSymbol){ + } else if (architectureElement instanceof LayerSymbol){ include((LayerSymbol) architectureElement, writer, netDefinitionMode); - } - else { + } else { include((IOSymbol) architectureElement, writer, netDefinitionMode); } } @@ -213,15 +207,15 @@ public class CNNArchTemplateController { } public Map.Entry process(String templateNameWithoutEnding, Target targetLanguage){ - StringWriter writer = new StringWriter(); + StringWriter newWriter = new StringWriter(); this.mainTemplateNameWithoutEnding = templateNameWithoutEnding; this.targetLanguage = targetLanguage; - this.writer = writer; + this.writer = newWriter; - include("", templateNameWithoutEnding, writer); + include("", templateNameWithoutEnding, newWriter); String fileEnding = targetLanguage.toString(); String fileName = getFileNameWithoutEnding() + fileEnding; - Map.Entry fileContent = new AbstractMap.SimpleEntry<>(fileName, writer.toString()); + Map.Entry fileContent = new AbstractMap.SimpleEntry<>(fileName, newWriter.toString()); this.mainTemplateNameWithoutEnding = null; this.targetLanguage = null; @@ -265,12 +259,12 @@ public class CNNArchTemplateController { } private boolean isTOutput(Class inputPredefinedLayerClass, ArchitectureElementSymbol architectureElement){ - if (architectureElement.isOutput()){ - if (architectureElement.getInputElement().isPresent() && architectureElement.getInputElement().get() instanceof LayerSymbol){ - LayerSymbol inputLayer = (LayerSymbol) architectureElement.getInputElement().get(); - if (inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration())){ - return true; - } + if (architectureElement.isOutput() + && architectureElement.getInputElement().isPresent() + && architectureElement.getInputElement().get() instanceof LayerSymbol){ + LayerSymbol inputLayer = (LayerSymbol) architectureElement.getInputElement().get(); + if (inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration())){ + return true; } } return false; diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java index 00728a457ce823fb41ba3f7375341e9486a3a6b4..845f6b4705134853959a86a663d22d97c6b0f716 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java @@ -37,7 +37,9 @@ public class CNNTrain2Gluon implements CNNTrainGenerator { it = configuration.getEntryMap().keySet().iterator(); while (it.hasNext()) { String key = it.next().toString(); - if (funcChecker.getUnsupportedElemList().contains(key)) it.remove(); + if (funcChecker.getUnsupportedElemList().contains(key)) { + it.remove(); + } } } @@ -52,12 +54,19 @@ public class CNNTrain2Gluon implements CNNTrainGenerator { Iterator it = configuration.getOptimizer().getOptimizerParamMap().keySet().iterator(); while (it.hasNext()) { String key = it.next().toString(); - if (funcChecker.getUnsupportedElemList().contains(key)) it.remove(); + if (funcChecker.getUnsupportedElemList().contains(key)) { + it.remove(); + } } } } } + private static void quitGeneration(){ + Log.error("Code generation is aborted"); + System.exit(1); + } + public CNNTrain2Gluon() { setGenerationTargetPath("./target/generated-sources-cnnarch/"); } @@ -89,7 +98,7 @@ public class CNNTrain2Gluon implements CNNTrainGenerator { Optional compilationUnit = scope.resolve(rootModelName, CNNTrainCompilationUnitSymbol.KIND); if (!compilationUnit.isPresent()) { Log.error("could not resolve training configuration " + rootModelName); - System.exit(1); + quitGeneration(); } setInstanceName(compilationUnit.get().getFullName()); CNNTrainCocos.checkAll(compilationUnit.get()); @@ -107,7 +116,7 @@ public class CNNTrain2Gluon implements CNNTrainGenerator { genCPP.generateFile(new FileContent(fileContents.get(fileName), fileName)); } } catch (IOException e) { - e.printStackTrace(); + Log.error("CNNTrainer file could not be generated" + e.getMessage()); } } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ConfigurationData.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ConfigurationData.java index be686c368abe0dd76a45d27240b1094f635e1264..10cf8dd764d4b81ba97c615590441c8eed0e60d3 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ConfigurationData.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ConfigurationData.java @@ -89,8 +89,7 @@ public class ConfigurationData { Class realClass = entry.getValue().getValue().getValue().getClass(); if (realClass == Boolean.class) { valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False"; - } - else if (lrPolicyClasses.contains(realClass)) { + } else if (lrPolicyClasses.contains(realClass)) { valueAsString = "'" + valueAsString + "'"; } mapToStrings.put(paramName, valueAsString); diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/LayerNameCreator.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/LayerNameCreator.java index 1032ddaa021c62710841ec3b58ed14236cc70a58..0b91f1023e57a1b78350963bdc3a59d672ee1e1d 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/LayerNameCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/LayerNameCreator.java @@ -47,17 +47,14 @@ public class LayerNameCreator { protected int name(ArchitectureElementSymbol architectureElement, int stage, List streamIndices){ if (architectureElement instanceof CompositeElementSymbol){ return nameComposite((CompositeElementSymbol) architectureElement, stage, streamIndices); - } - else{ + } else{ if (architectureElement.isAtomic()){ if (architectureElement.getMaxSerialLength().get() > 0){ return add(architectureElement, stage, streamIndices); - } - else { + } else { return stage; } - } - else { + } else { ArchitectureElementSymbol resolvedElement = architectureElement.getResolvedThis().get(); return name(resolvedElement, stage, streamIndices); } @@ -78,8 +75,7 @@ public class LayerNameCreator { streamIndices.remove(lastIndex); return Collections.max(endStages) + 1; - } - else { + } else { int endStage = stage; for (ArchitectureElementSymbol subElement : compositeElement.getElements()){ endStage = name(subElement, endStage, streamIndices); @@ -113,8 +109,7 @@ public class LayerNameCreator { name = name + "_" + arrayAccess + "_"; } return name; - } - else { + } else { return createBaseName(architectureElement) + stage + createStreamPostfix(streamIndices) + "_"; } } @@ -132,11 +127,9 @@ public class LayerNameCreator { } else { return layerDeclaration.getName().toLowerCase(); } - } - else if (architectureElement instanceof CompositeElementSymbol){ + } else if (architectureElement instanceof CompositeElementSymbol){ return "group"; - } - else { + } else { return architectureElement.getName(); } } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TemplateConfiguration.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TemplateConfiguration.java index aa9897abb9f0b8af9067570793cbe4580bfb2ec3..d79e9076fc1308aef6e3f9f45b103432c8cd3ab5 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TemplateConfiguration.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TemplateConfiguration.java @@ -43,6 +43,11 @@ public class TemplateConfiguration { configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER); } + private static void quitGeneration(){ + Log.error("Code generation is aborted"); + System.exit(1); + } + public Configuration getConfiguration() { return configuration; } @@ -58,14 +63,12 @@ public class TemplateConfiguration { try{ Template template = TemplateConfiguration.get().getTemplate(templatePath); template.process(ftlContext, writer); - } - catch (IOException e) { + } catch (IOException e) { Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage()); - System.exit(1); - } - catch (TemplateException e){ + quitGeneration(); + } catch (TemplateException e){ Log.error("An exception occured in template " + templatePath + " :\n" + e.getMessage()); - System.exit(1); + quitGeneration(); } } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TrainParamSupportChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TrainParamSupportChecker.java index e559eb4896f009225d688d24a03243bc1968036b..d145e6ff305a0cd72e938c0e4085bb7f118192ff 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TrainParamSupportChecker.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TrainParamSupportChecker.java @@ -25,12 +25,14 @@ public class TrainParamSupportChecker implements CNNTrainVisitor { public TrainParamSupportChecker() { } - public String unsupportedOptFlag = "unsupported_optimizer"; + public static final String unsupportedOptFlag = "unsupported_optimizer"; public List getUnsupportedElemList(){ return this.unsupportedElemList; } + //Empty visit method denotes that the corresponding training parameter is supported. + //To set a training parameter as unsupported, add the corresponding node to the unsupportedElemList public void visit(ASTNumEpochEntry node){} public void visit(ASTBatchSizeEntry node){}