diff --git a/pom.xml b/pom.xml index 007cea102930abef4a921ddc0d977b76dc812334..d8e5ee9b4ece21418522e592d6ac2a89a2c70270 100644 --- a/pom.xml +++ b/pom.xml @@ -15,7 +15,7 @@ - 0.2.6 + 0.2.7-SNAPSHOT 0.2.6 0.2.14-SNAPSHOT 0.2.11-SNAPSHOT @@ -62,6 +62,12 @@ + + de.monticore.lang.monticar + common-monticar + 0.0.17-SNAPSHOT + + de.monticore.lang.monticar embedded-montiarc-math-opt-generator diff --git a/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java b/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java index f48e6cf62ffa67df7600ea1b5c0df22be54332aa..6890b23ccae2ba74b87d82bce63e08a36cf354a5 100644 --- a/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java +++ b/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java @@ -26,7 +26,7 @@ import com.google.common.base.Splitter; import com.google.common.io.Resources; import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.cncModel.EMAComponentSymbol; import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol; -import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstantiationSymbol; +import de.monticore.lang.embeddedmontiarcdynamic.embeddedmontiarcdynamic._symboltable.cncModel.EMADynamicComponentSymbol; import de.monticore.lang.math._symboltable.MathStatementsSymbol; import de.monticore.lang.monticar.cnnarch.CNNArchGenerator; import de.monticore.lang.monticar.cnnarch.DataPathConfigParser; @@ -34,7 +34,6 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol; import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; import de.monticore.lang.monticar.emadl._cocos.EMADLCocos; -import de.monticore.lang.monticar.emadl._cocos.DataPathCocos; import de.monticore.lang.monticar.emadl.tagging.dltag.DataPathSymbol; import de.monticore.lang.monticar.generator.FileContent; import de.monticore.lang.monticar.generator.cpp.ArmadilloHelper; @@ -49,15 +48,18 @@ import de.se_rwth.commons.Splitters; import de.se_rwth.commons.logging.Log; import freemarker.template.TemplateException; -import javax.xml.bind.DatatypeConverter; import java.io.*; import java.nio.charset.Charset; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.*; + import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; -import java.util.*; +import java.security.DigestInputStream; + +import javax.xml.bind.DatatypeConverter; public class EMADLGenerator { @@ -67,6 +69,8 @@ public class EMADLGenerator { private Backend backend; private String modelsPath; + + public EMADLGenerator(Backend backend) { this.backend = backend; @@ -82,22 +86,24 @@ public class EMADLGenerator { } public void setModelsPath(String modelsPath) { - if (!(modelsPath.substring(modelsPath.length() - 1).equals("/"))) { + if (!(modelsPath.substring(modelsPath.length() - 1).equals("/"))){ this.modelsPath = modelsPath + "/"; - } else { + } + else { this.modelsPath = modelsPath; } } - public void setGenerationTargetPath(String generationTargetPath) { - if (!(generationTargetPath.substring(generationTargetPath.length() - 1).equals("/"))) { + public void setGenerationTargetPath(String generationTargetPath){ + if (!(generationTargetPath.substring(generationTargetPath.length() - 1).equals("/"))){ getEmamGen().setGenerationTargetPath(generationTargetPath + "/"); - } else { + } + else { getEmamGen().setGenerationTargetPath(generationTargetPath); } } - public String getGenerationTargetPath() { + public String getGenerationTargetPath(){ return getEmamGen().getGenerationTargetPath(); } @@ -105,27 +111,24 @@ public class EMADLGenerator { return emamGen; } - public void generate(String modelPath, String qualifiedName, String pythonPath, String forced, boolean doCompile) - throws IOException, TemplateException { - setModelsPath(modelPath); + public void generate(String modelPath, String qualifiedName, String pythonPath, String forced, boolean doCompile) throws IOException, TemplateException { + setModelsPath( modelPath ); TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath()); - EMAComponentSymbol component = symtab.resolve(qualifiedName, EMAComponentSymbol.KIND) - .orElse(null); + EMAComponentSymbol component = symtab.resolve(qualifiedName, EMAComponentSymbol.KIND).orElse(null); List splitName = Splitters.DOT.splitToList(qualifiedName); String componentName = splitName.get(splitName.size() - 1); String instanceName = componentName.substring(0, 1).toLowerCase() + componentName.substring(1); - if (component == null) { + if (component == null){ Log.error("Component with name '" + componentName + "' does not exist."); System.exit(1); } - EMAComponentInstanceSymbol instance = component.getEnclosingScope() - .resolve(instanceName, EMAComponentInstanceSymbol.KIND).get(); + EMAComponentInstanceSymbol instance = component.getEnclosingScope().resolve(instanceName, EMAComponentInstanceSymbol.KIND).get(); generateFiles(symtab, instance, symtab, pythonPath, forced); - + if (doCompile) { compile(); } @@ -138,11 +141,11 @@ public class EMADLGenerator { pb.inheritIO(); Process process = pb.start(); int returnCode = process.waitFor(); - if (returnCode != 0) { + if(returnCode != 0) { Log.error("During compilation, an error occured. See above for more details."); System.exit(1); } - } catch (Exception e) { + }catch(Exception e){ Log.error("During compilation, the following error occured: '" + e.toString() + "'"); System.exit(1); } finally { @@ -150,10 +153,11 @@ public class EMADLGenerator { } } - public File createTempScript() throws IOException { + public File createTempScript() throws IOException{ File tempScript = File.createTempFile("script", null); - try { - Writer streamWriter = new OutputStreamWriter(new FileOutputStream(tempScript)); + try{ + Writer streamWriter = new OutputStreamWriter(new FileOutputStream( + tempScript)); PrintWriter printWriter = new PrintWriter(streamWriter); printWriter.println("#!/bin/bash"); @@ -164,7 +168,7 @@ public class EMADLGenerator { printWriter.println("make"); printWriter.close(); - } catch (Exception e) { + }catch(Exception e){ System.out.println(e); } @@ -187,11 +191,9 @@ public class EMADLGenerator { } } - public void generateFiles(TaggingResolver taggingResolver, EMAComponentInstanceSymbol EMAComponentSymbol, - Scope symtab, String pythonPath, String forced) throws IOException { + public void generateFiles(TaggingResolver taggingResolver, EMAComponentInstanceSymbol EMAComponentSymbol, Scope symtab, String pythonPath, String forced) throws IOException { Set allInstances = new HashSet<>(); - List fileContents = generateStrings(taggingResolver, EMAComponentSymbol, symtab, allInstances, - forced); + List fileContents = generateStrings(taggingResolver, EMAComponentSymbol, symtab, allInstances, forced); for (FileContent fileContent : fileContents) { emamGen.generateFile(fileContent); @@ -199,77 +201,74 @@ public class EMADLGenerator { // train Map fileContentMap = new HashMap<>(); - for (FileContent f : fileContents) { + for(FileContent f : fileContents) { fileContentMap.put(f.getFileName(), f.getFileContent()); } List fileContentsTrainingHashes = new ArrayList<>(); List newHashes = new ArrayList<>(); for (EMAComponentInstanceSymbol componentInstance : allInstances) { - Optional architecture = componentInstance.getSpannedScope().resolve("", - ArchitectureSymbol.KIND); + Optional architecture = componentInstance.getSpannedScope().resolve("", ArchitectureSymbol.KIND); - if (!architecture.isPresent()) { + if(!architecture.isPresent()) { continue; } - if (forced.equals("n")) { + if(forced.equals("n")) { continue; } - String configFilename = getConfigFilename(componentInstance.getComponentType().getFullName(), - componentInstance.getFullName(), componentInstance.getName()); + String configFilename = getConfigFilename(componentInstance.getComponentType().getFullName(), componentInstance.getFullName(), componentInstance.getName()); String emadlPath = getModelsPath() + configFilename + ".emadl"; String cnntPath = getModelsPath() + configFilename + ".cnnt"; String emadlHash = getChecksumForFile(emadlPath); String cnntHash = getChecksumForFile(cnntPath); - String componentConfigFilename = componentInstance.getComponentType().getReferencedSymbol().getFullName() - .replaceAll("\\.", "/"); + String componentConfigFilename = componentInstance.getComponentType().getReferencedSymbol().getFullName().replaceAll("\\.", "/"); String b = backend.getBackendString(backend); String trainingDataHash = ""; String testDataHash = ""; - if (b.equals("CAFFE2")) { + if(b.equals("CAFFE2")){ trainingDataHash = getChecksumForFile(architecture.get().getDataPath() + "/train_lmdb/data.mdb"); testDataHash = getChecksumForFile(architecture.get().getDataPath() + "/test_lmdb/data.mdb"); - } else { + }else{ trainingDataHash = getChecksumForFile(architecture.get().getDataPath() + "/train.h5"); testDataHash = getChecksumForFile(architecture.get().getDataPath() + "/test.h5"); } String trainingHash = emadlHash + "#" + cnntHash + "#" + trainingDataHash + "#" + testDataHash; - boolean alreadyTrained = newHashes.contains(trainingHash) - || isAlreadyTrained(trainingHash, componentInstance); - if (alreadyTrained && !forced.equals("y")) { + boolean alreadyTrained = newHashes.contains(trainingHash) || isAlreadyTrained(trainingHash, componentInstance); + if(alreadyTrained && !forced.equals("y")) { Log.warn("Training of model " + componentInstance.getFullName() + " skipped"); - } else { - String parsedFullName = componentInstance.getFullName().substring(0, 1).toLowerCase() - + componentInstance.getFullName().substring(1).replaceAll("\\.", "_"); + } + else { + String parsedFullName = componentInstance.getFullName().substring(0, 1).toLowerCase() + componentInstance.getFullName().substring(1).replaceAll("\\.", "_"); String trainerScriptName = "CNNTrainer_" + parsedFullName + ".py"; String trainingPath = getGenerationTargetPath() + trainerScriptName; - if (Files.exists(Paths.get(trainingPath))) { + if(Files.exists(Paths.get(trainingPath))){ ProcessBuilder pb = new ProcessBuilder(Arrays.asList(pythonPath, trainingPath)).inheritIO(); Process p = pb.start(); int exitCode = 0; try { exitCode = p.waitFor(); - } catch (InterruptedException e) { + } + catch(InterruptedException e) { Log.error("Training aborted: exit code " + Integer.toString(exitCode)); System.exit(1); } - if (exitCode != 0) { + if(exitCode != 0) { Log.error("Training failed: exit code " + Integer.toString(exitCode)); System.exit(1); } - fileContentsTrainingHashes - .add(new FileContent(trainingHash, componentConfigFilename + ".training_hash")); + fileContentsTrainingHashes.add(new FileContent(trainingHash, componentConfigFilename + ".training_hash")); newHashes.add(trainingHash); - } else { + } + else{ System.out.println("Trainingfile " + trainingPath + " not found."); } } @@ -284,7 +283,8 @@ public class EMADLGenerator { private static String convertByteArrayToHexString(byte[] arrayBytes) { StringBuffer stringBuffer = new StringBuffer(); for (int i = 0; i < arrayBytes.length; i++) { - stringBuffer.append(Integer.toString((arrayBytes[i] & 0xff) + 0x100, 16).substring(1)); + stringBuffer.append(Integer.toString((arrayBytes[i] & 0xff) + 0x100, 16) + .substring(1)); } return stringBuffer.toString(); } @@ -295,25 +295,24 @@ public class EMADLGenerator { String componentConfigFilename = component.getFullName().replaceAll("\\.", "/"); String checkFilePathString = getGenerationTargetPath() + componentConfigFilename + ".training_hash"; - Path checkFilePath = Paths.get(checkFilePathString); - if (Files.exists(checkFilePath)) { + Path checkFilePath = Paths.get( checkFilePathString); + if(Files.exists(checkFilePath)) { List hashes = Files.readAllLines(checkFilePath); - for (String hash : hashes) { - if (hash.equals(trainingHash)) { + for(String hash : hashes) { + if(hash.equals(trainingHash)) { return true; } } } return false; - } catch (Exception e) { + } + catch(Exception e) { return false; } } - public List generateStrings(TaggingResolver taggingResolver, - EMAComponentInstanceSymbol componentInstanceSymbol, Scope symtab, - Set allInstances, String forced) { + public List generateStrings(TaggingResolver taggingResolver, EMAComponentInstanceSymbol componentInstanceSymbol, Scope symtab, Set allInstances, String forced){ List fileContents = new ArrayList<>(); generateComponent(fileContents, allInstances, taggingResolver, componentInstanceSymbol, symtab); @@ -327,16 +326,14 @@ public class EMADLGenerator { if (cnnArchGenerator.isCMakeRequired()) { cnnArchGenerator.setGenerationTargetPath(getGenerationTargetPath()); - Map cmakeContentsMap = cnnArchGenerator - .generateCMakeContent(componentInstanceSymbol.getFullName()); - for (String fileName : cmakeContentsMap.keySet()) { + Map cmakeContentsMap = cnnArchGenerator.generateCMakeContent(componentInstanceSymbol.getFullName()); + for (String fileName : cmakeContentsMap.keySet()){ fileContents.add(new FileContent(cmakeContentsMap.get(fileName), fileName)); } } if (emamGen.shouldGenerateMainClass()) { - // fileContents.add(emamGen.getMainClassFileContent(componentInstanceSymbol, - // fileContents.get(0))); + //fileContents.add(emamGen.getMainClassFileContent(componentInstanceSymbol, fileContents.get(0))); } else if (emamGen.shouldGenerateSimulatorInterface()) { fileContents.addAll(SimulatorIntegrationHelper.getSimulatorIntegrationHelperFileContent()); } @@ -346,36 +343,17 @@ public class EMADLGenerator { return fileContents; } - /** - * returns data path either from tags or data_paths.txt - */ - protected String getDataPath(TaggingResolver taggingResolver, EMAComponentSymbol component, - EMAComponentInstanceSymbol instance) { - List instanceTags = new LinkedList<>(); - - boolean isChildComponent = instance.getEnclosingComponent().isPresent(); - - if (isChildComponent) { - // get all instantiated components of parent - List instantiationSymbols = (List) instance - .getEnclosingComponent().get().getComponentType().getReferencedSymbol().getSubComponents(); - - // filter corresponding instantiation of instance and add tags - instantiationSymbols.stream().filter(e -> e.getName().equals(instance.getName())).findFirst() - .ifPresent(symbol -> instanceTags.addAll(taggingResolver.getTags(symbol, DataPathSymbol.KIND))); - } - + protected String getDataPath(TaggingResolver taggingResolver, EMAComponentSymbol component, EMAComponentInstanceSymbol instance){ // instance tags have priority - List tags = !instanceTags.isEmpty() ? instanceTags - : (List) taggingResolver.getTags(component, DataPathSymbol.KIND); + List instanceTags = (List) taggingResolver.getTags(instance, DataPathSymbol.KIND); + List tags = !instanceTags.isEmpty() ? instanceTags : + (List) taggingResolver.getTags(component, DataPathSymbol.KIND); String dataPath; - if (!tags.isEmpty()) { - DataPathCocos.check(component, taggingResolver); dataPath = (String) tags.get(0).getValues().get(0); - Log.warn("Tagfile was found, ignoring data_paths.txt: " + dataPath); - } else { + } + else { DataPathConfigParser newParserConfig = new DataPathConfigParser(getModelsPath() + "data_paths.txt"); dataPath = newParserConfig.getDataPath(component.getFullName()); } @@ -383,104 +361,103 @@ public class EMADLGenerator { return dataPath; } - protected void generateComponent(List fileContents, Set allInstances, - TaggingResolver taggingResolver, EMAComponentInstanceSymbol componentInstanceSymbol, Scope symtab) { + protected void generateComponent(List fileContents, + Set allInstances, + TaggingResolver taggingResolver, + EMAComponentInstanceSymbol componentInstanceSymbol, + Scope symtab){ allInstances.add(componentInstanceSymbol); EMAComponentSymbol EMAComponentSymbol = componentInstanceSymbol.getComponentType().getReferencedSymbol(); - /* - * remove the following two lines if the component symbol full name bug with - * generic variables is fixed - */ + /* remove the following two lines if the component symbol full name bug with generic variables is fixed */ EMAComponentSymbol.setFullName(null); EMAComponentSymbol.getFullName(); /* */ - Optional architecture = componentInstanceSymbol.getSpannedScope().resolve("", - ArchitectureSymbol.KIND); - Optional mathStatements = EMAComponentSymbol.getSpannedScope().resolve("MathStatements", - MathStatementsSymbol.KIND); + Optional architecture = componentInstanceSymbol.getSpannedScope().resolve("", ArchitectureSymbol.KIND); + Optional mathStatements = EMAComponentSymbol.getSpannedScope().resolve("MathStatements", MathStatementsSymbol.KIND); EMADLCocos.checkAll(componentInstanceSymbol); - if (architecture.isPresent()) { + if (architecture.isPresent()){ String dPath = getDataPath(taggingResolver, EMAComponentSymbol, componentInstanceSymbol); architecture.get().setDataPath(dPath); architecture.get().setComponentName(EMAComponentSymbol.getFullName()); generateCNN(fileContents, taggingResolver, componentInstanceSymbol, architecture.get()); - } else if (mathStatements.isPresent()) { + } + else if (mathStatements.isPresent()){ generateMathComponent(fileContents, taggingResolver, componentInstanceSymbol, mathStatements.get()); - } else { + } + else { generateSubComponents(fileContents, allInstances, taggingResolver, componentInstanceSymbol, symtab); } } - private void fixArmadilloImports(List fileContents) { - for (FileContent fileContent : fileContents) { - fileContent.setFileContent( - fileContent.getFileContent().replaceFirst("#include \"armadillo.h\"", "#include \"armadillo\"")); + private void fixArmadilloImports(List fileContents){ + for (FileContent fileContent : fileContents){ + fileContent.setFileContent(fileContent.getFileContent() + .replaceFirst("#include \"armadillo.h\"", + "#include \"armadillo\"")); } } - public void generateCNN(List fileContents, TaggingResolver taggingResolver, - EMAComponentInstanceSymbol instance, ArchitectureSymbol architecture) { - Map contentMap = cnnArchGenerator.generateStrings(architecture); + public void generateCNN(List fileContents, TaggingResolver taggingResolver, EMAComponentInstanceSymbol instance, ArchitectureSymbol architecture){ + Map contentMap = cnnArchGenerator.generateStrings(architecture); String fullName = instance.getFullName().replaceAll("\\.", "_"); - // get the components execute method + //get the components execute method String executeKey = "execute_" + fullName; String executeMethod = contentMap.get(executeKey); - if (executeMethod == null) { + if (executeMethod == null){ throw new IllegalStateException("execute method of " + fullName + " not found"); } contentMap.remove(executeKey); String component = emamGen.generateString(taggingResolver, instance, (MathStatementsSymbol) null); FileContent componentFileContent = new FileContent( - transformComponent(component, "CNNPredictor_" + fullName, executeMethod), instance); + transformComponent(component, "CNNPredictor_" + fullName, executeMethod), + instance); - for (String fileName : contentMap.keySet()) { + for (String fileName : contentMap.keySet()){ fileContents.add(new FileContent(contentMap.get(fileName), fileName)); } fileContents.add(componentFileContent); fileContents.add(new FileContent(readResource("CNNTranslator.h", Charsets.UTF_8), "CNNTranslator.h")); } - protected String transformComponent(String component, String predictorClassName, String executeMethod) { + protected String transformComponent(String component, String predictorClassName, String executeMethod){ String networkVariableName = "_cnn_"; - // insert includes - component = component.replaceFirst("using namespace", "#include \"" + predictorClassName + ".h" + "\"\n" - + "#include \"CNNTranslator.h\"\n" + "using namespace"); + //insert includes + component = component.replaceFirst("using namespace", + "#include \"" + predictorClassName + ".h" + "\"\n" + + "#include \"CNNTranslator.h\"\n" + + "using namespace"); - // insert network attribute + //insert network attribute component = component.replaceFirst("public:", "public:\n" + predictorClassName + " " + networkVariableName + ";"); - // insert execute method + //insert execute method component = component.replaceFirst("void execute\\(\\)\\s\\{\\s\\}", "void execute(){\n" + executeMethod + "\n}"); return component; } - public void generateMathComponent(List fileContents, TaggingResolver taggingResolver, - EMAComponentInstanceSymbol EMAComponentSymbol, MathStatementsSymbol mathStatementsSymbol) { + public void generateMathComponent(List fileContents, TaggingResolver taggingResolver, EMAComponentInstanceSymbol EMAComponentSymbol, MathStatementsSymbol mathStatementsSymbol){ fileContents.add(new FileContent( - emamGen.generateString(taggingResolver, EMAComponentSymbol, mathStatementsSymbol), EMAComponentSymbol)); + emamGen.generateString(taggingResolver, EMAComponentSymbol, mathStatementsSymbol), + EMAComponentSymbol)); } - public void generateSubComponents(List fileContents, Set allInstances, - TaggingResolver taggingResolver, EMAComponentInstanceSymbol componentInstanceSymbol, Scope symtab) { - fileContents.add(new FileContent( - emamGen.generateString(taggingResolver, componentInstanceSymbol, (MathStatementsSymbol) null), - componentInstanceSymbol)); + public void generateSubComponents(List fileContents, Set allInstances, TaggingResolver taggingResolver, EMAComponentInstanceSymbol componentInstanceSymbol, Scope symtab){ + fileContents.add(new FileContent(emamGen.generateString(taggingResolver, componentInstanceSymbol, (MathStatementsSymbol) null), componentInstanceSymbol)); String lastNameWithoutArrayPart = ""; for (EMAComponentInstanceSymbol instanceSymbol : componentInstanceSymbol.getSubComponents()) { int arrayBracketIndex = instanceSymbol.getName().indexOf("["); boolean generateComponentInstance = true; if (arrayBracketIndex != -1) { - generateComponentInstance = !instanceSymbol.getName().substring(0, arrayBracketIndex) - .equals(lastNameWithoutArrayPart); + generateComponentInstance = !instanceSymbol.getName().substring(0, arrayBracketIndex).equals(lastNameWithoutArrayPart); lastNameWithoutArrayPart = instanceSymbol.getName().substring(0, arrayBracketIndex); Log.info(lastNameWithoutArrayPart, "Without:"); Log.info(generateComponentInstance + "", "Bool:"); @@ -495,45 +472,45 @@ public class EMADLGenerator { String trainConfigFilename; String mainComponentConfigFilename = mainComponentName.replaceAll("\\.", "/"); String componentConfigFilename = componentFullName.replaceAll("\\.", "/"); - String instanceConfigFilename = componentFullName.replaceAll("\\.", "/") + "_" + componentName; - if (Files.exists(Paths.get(getModelsPath() + instanceConfigFilename + ".cnnt"))) { + String instanceConfigFilename = componentFullName.replaceAll("\\.", "/") + "_" + componentName; + if (Files.exists(Paths.get( getModelsPath() + instanceConfigFilename + ".cnnt"))) { trainConfigFilename = instanceConfigFilename; - } else if (Files.exists(Paths.get(getModelsPath() + componentConfigFilename + ".cnnt"))) { + } + else if (Files.exists(Paths.get( getModelsPath() + componentConfigFilename + ".cnnt"))){ trainConfigFilename = componentConfigFilename; - } else if (Files.exists(Paths.get(getModelsPath() + mainComponentConfigFilename + ".cnnt"))) { + } + else if (Files.exists(Paths.get( getModelsPath() + mainComponentConfigFilename + ".cnnt"))){ trainConfigFilename = mainComponentConfigFilename; - } else { - Log.error("Missing configuration file. " - + "Could not find a file with any of the following names (only one needed): '" + getModelsPath() - + instanceConfigFilename + ".cnnt', '" + getModelsPath() + componentConfigFilename + ".cnnt', '" - + getModelsPath() + mainComponentConfigFilename + ".cnnt'." - + " These files denote respectively the configuration for the single instance, the component or the whole system."); + } + else{ + Log.error("Missing configuration file. " + + "Could not find a file with any of the following names (only one needed): '" + + getModelsPath() + instanceConfigFilename + ".cnnt', '" + + getModelsPath() + componentConfigFilename + ".cnnt', '" + + getModelsPath() + mainComponentConfigFilename + ".cnnt'." + + " These files denote respectively the configuration for the single instance, the component or the whole system."); return null; } return trainConfigFilename; } - public List generateCNNTrainer(Set allInstances, - String mainComponentName) { + public List generateCNNTrainer(Set allInstances, String mainComponentName) { List fileContents = new ArrayList<>(); for (EMAComponentInstanceSymbol componentInstance : allInstances) { EMAComponentSymbol component = componentInstance.getComponentType().getReferencedSymbol(); - Optional architecture = component.getSpannedScope().resolve("", - ArchitectureSymbol.KIND); + Optional architecture = component.getSpannedScope().resolve("", ArchitectureSymbol.KIND); if (architecture.isPresent()) { - String trainConfigFilename = getConfigFilename(mainComponentName, component.getFullName(), - component.getName()); + String trainConfigFilename = getConfigFilename(mainComponentName, component.getFullName(), component.getName()); - // should be removed when CNNTrain supports packages + //should be removed when CNNTrain supports packages List names = Splitter.on("/").splitToList(trainConfigFilename); - trainConfigFilename = names.get(names.size() - 1); - Path modelPath = Paths.get(getModelsPath() + Joiner.on("/").join(names.subList(0, names.size() - 1))); - ConfigurationSymbol configuration = cnnTrainGenerator.getConfigurationSymbol(modelPath, - trainConfigFilename); + trainConfigFilename = names.get(names.size()-1); + Path modelPath = Paths.get(getModelsPath() + Joiner.on("/").join(names.subList(0,names.size()-1))); + ConfigurationSymbol configuration = cnnTrainGenerator.getConfigurationSymbol(modelPath, trainConfigFilename); cnnTrainGenerator.setInstanceName(componentInstance.getFullName().replaceAll("\\.", "_")); - Map fileContentMap = cnnTrainGenerator.generateStrings(configuration); - for (String fileName : fileContentMap.keySet()) { + Map fileContentMap = cnnTrainGenerator.generateStrings(configuration); + for (String fileName : fileContentMap.keySet()){ fileContents.add(new FileContent(fileContentMap.get(fileName), fileName)); } } diff --git a/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java b/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java index fe6593be55f668a4c64ad06bbc72e2fdedd3af05..df5b87acd20f356cb8cdd3afdb1e03dac13a9c2b 100644 --- a/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java +++ b/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java @@ -37,7 +37,6 @@ import java.util.Arrays; import java.util.List; import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; public class GenerationTest extends AbstractSymtabTest { @@ -52,24 +51,29 @@ public class GenerationTest extends AbstractSymtabTest { @Test public void testCifar10Generation() throws IOException, TemplateException { Log.getFindings().clear(); - String[] args = { "-m", "src/test/resources/models/", "-r", "cifar10.Cifar10Classifier", "-b", "MXNET", "-f", - "n", "-c", "n" }; + String[] args = {"-m", "src/test/resources/models/", "-r", "cifar10.Cifar10Classifier", "-b", "MXNET", "-f", "n", "-c", "n"}; EMADLGeneratorCli.main(args); assertTrue(Log.getFindings().isEmpty()); - checkFilesAreEqual(Paths.get("./target/generated-sources-emadl"), Paths.get("./src/test/resources/target_code"), - Arrays.asList("cifar10_cifar10Classifier.cpp", "cifar10_cifar10Classifier.h", - "CNNCreator_cifar10_cifar10Classifier_net.py", "CNNBufferFile.h", - "CNNPredictor_cifar10_cifar10Classifier_net.h", "cifar10_cifar10Classifier_net.h", - "CNNTranslator.h", "cifar10_cifar10Classifier_calculateClass.h", + checkFilesAreEqual( + Paths.get("./target/generated-sources-emadl"), + Paths.get("./src/test/resources/target_code"), + Arrays.asList( + "cifar10_cifar10Classifier.cpp", + "cifar10_cifar10Classifier.h", + "CNNCreator_cifar10_cifar10Classifier_net.py", + "CNNBufferFile.h", + "CNNPredictor_cifar10_cifar10Classifier_net.h", + "cifar10_cifar10Classifier_net.h", + "CNNTranslator.h", + "cifar10_cifar10Classifier_calculateClass.h", "CNNTrainer_cifar10_cifar10Classifier_net.py")); } @Test public void testSimulatorGeneration() throws IOException, TemplateException { Log.getFindings().clear(); - String[] args = { "-m", "src/test/resources/models/", "-r", "simulator.MainController", "-b", "MXNET", "-f", - "n", "-c", "n" }; + String[] args = {"-m", "src/test/resources/models/", "-r", "simulator.MainController", "-b", "MXNET", "-f", "n", "-c", "n"}; EMADLGeneratorCli.main(args); assertTrue(Log.getFindings().isEmpty()); } @@ -77,7 +81,7 @@ public class GenerationTest extends AbstractSymtabTest { @Test public void testAddGeneration() throws IOException, TemplateException { Log.getFindings().clear(); - String[] args = { "-m", "src/test/resources/models/", "-r", "Add", "-b", "MXNET", "-f", "n", "-c", "n" }; + String[] args = {"-m", "src/test/resources/models/", "-r", "Add", "-b", "MXNET", "-f", "n", "-c", "n"}; EMADLGeneratorCli.main(args); assertTrue(Log.getFindings().isEmpty()); } @@ -85,7 +89,7 @@ public class GenerationTest extends AbstractSymtabTest { @Test public void testAlexnetGeneration() throws IOException, TemplateException { Log.getFindings().clear(); - String[] args = { "-m", "src/test/resources/models/", "-r", "Alexnet", "-b", "MXNET", "-f", "n", "-c", "n" }; + String[] args = {"-m", "src/test/resources/models/", "-r", "tagging.Parent", "-b", "MXNET", "-f", "n", "-c", "n"}; EMADLGeneratorCli.main(args); assertTrue(Log.getFindings().isEmpty()); } @@ -93,7 +97,7 @@ public class GenerationTest extends AbstractSymtabTest { @Test public void testResNeXtGeneration() throws IOException, TemplateException { Log.getFindings().clear(); - String[] args = { "-m", "src/test/resources/models/", "-r", "ResNeXt50", "-b", "MXNET", "-f", "n", "-c", "n" }; + String[] args = {"-m", "src/test/resources/models/", "-r", "ResNeXt50", "-b", "MXNET", "-f", "n", "-c", "n"}; EMADLGeneratorCli.main(args); assertTrue(Log.getFindings().isEmpty()); } @@ -101,8 +105,7 @@ public class GenerationTest extends AbstractSymtabTest { @Test public void testThreeInputGeneration() throws IOException, TemplateException { Log.getFindings().clear(); - String[] args = { "-m", "src/test/resources/models/", "-r", "ThreeInputCNN_M14", "-b", "MXNET", "-f", "n", "-c", - "n" }; + String[] args = {"-m", "src/test/resources/models/", "-r", "ThreeInputCNN_M14", "-b", "MXNET", "-f", "n", "-c", "n"}; EMADLGeneratorCli.main(args); assertTrue(Log.getFindings().size() == 1); } @@ -110,8 +113,7 @@ public class GenerationTest extends AbstractSymtabTest { @Test public void testMultipleOutputsGeneration() throws IOException, TemplateException { Log.getFindings().clear(); - String[] args = { "-m", "src/test/resources/models/", "-r", "MultipleOutputs", "-b", "MXNET", "-f", "n", "-c", - "n" }; + String[] args = {"-m", "src/test/resources/models/", "-r", "MultipleOutputs", "-b", "MXNET", "-f", "n", "-c", "n"}; EMADLGeneratorCli.main(args); assertTrue(Log.getFindings().size() == 1); } @@ -119,7 +121,7 @@ public class GenerationTest extends AbstractSymtabTest { @Test public void testVGGGeneration() throws IOException, TemplateException { Log.getFindings().clear(); - String[] args = { "-m", "src/test/resources/models/", "-r", "VGG16", "-b", "MXNET", "-f", "n", "-c", "n" }; + String[] args = {"-m", "src/test/resources/models/", "-r", "VGG16", "-b", "MXNET", "-f", "n", "-c", "n"}; EMADLGeneratorCli.main(args); assertTrue(Log.getFindings().isEmpty()); } @@ -128,11 +130,11 @@ public class GenerationTest extends AbstractSymtabTest { public void testMultipleInstances() throws IOException, TemplateException { try { Log.getFindings().clear(); - String[] args = { "-m", "src/test/resources/models/", "-r", "InstanceTest.MainB", "-b", "MXNET", "-f", "n", - "-c", "n" }; + String[] args = {"-m", "src/test/resources/models/", "-r", "InstanceTest.MainB", "-b", "MXNET", "-f", "n", "-c", "n"}; EMADLGeneratorCli.main(args); assertTrue(Log.getFindings().isEmpty()); - } catch (Exception e) { + } + catch(Exception e) { e.printStackTrace(); } } @@ -140,86 +142,59 @@ public class GenerationTest extends AbstractSymtabTest { @Test public void testMnistClassifier() throws IOException, TemplateException { Log.getFindings().clear(); - String[] args = { "-m", "src/test/resources/models/", "-r", "mnist.MnistClassifier", "-b", "CAFFE2", "-f", "n", - "-c", "n" }; + String[] args = {"-m", "src/test/resources/models/", "-r", "mnist.MnistClassifier", "-b", "CAFFE2", "-f", "n", "-c", "n"}; EMADLGeneratorCli.main(args); assertTrue(Log.getFindings().isEmpty()); - checkFilesAreEqual(Paths.get("./target/generated-sources-emadl"), Paths.get("./src/test/resources/target_code"), - Arrays.asList("mnist_mnistClassifier.cpp", "mnist_mnistClassifier.h", - "CNNCreator_mnist_mnistClassifier_net.py", "CNNPredictor_mnist_mnistClassifier_net.h", - "mnist_mnistClassifier_net.h", "CNNTranslator.h", "mnist_mnistClassifier_calculateClass.h", + checkFilesAreEqual( + Paths.get("./target/generated-sources-emadl"), + Paths.get("./src/test/resources/target_code"), + Arrays.asList( + "mnist_mnistClassifier.cpp", + "mnist_mnistClassifier.h", + "CNNCreator_mnist_mnistClassifier_net.py", + "CNNPredictor_mnist_mnistClassifier_net.h", + "mnist_mnistClassifier_net.h", + "CNNTranslator.h", + "mnist_mnistClassifier_calculateClass.h", "CNNTrainer_mnist_mnistClassifier_net.py")); } @Test public void testMnistClassifierForGluon() throws IOException, TemplateException { Log.getFindings().clear(); - String[] args = { "-m", "src/test/resources/models/", "-r", "mnist.MnistClassifier", "-b", "GLUON", "-f", "n", - "-c", "n" }; + String[] args = {"-m", "src/test/resources/models/", "-r", "mnist.MnistClassifier", "-b", "GLUON", "-f", "n", "-c", "n"}; EMADLGeneratorCli.main(args); assertTrue(Log.getFindings().isEmpty()); - checkFilesAreEqual(Paths.get("./target/generated-sources-emadl"), + checkFilesAreEqual( + Paths.get("./target/generated-sources-emadl"), Paths.get("./src/test/resources/target_code/gluon"), - Arrays.asList("CNNBufferFile.h", "CNNNet_mnist_mnistClassifier_net.py", "mnist_mnistClassifier.cpp", - "mnist_mnistClassifier.h", "CNNCreator_mnist_mnistClassifier_net.py", - "CNNPredictor_mnist_mnistClassifier_net.h", "CNNDataLoader_mnist_mnistClassifier_net.py", - "supervised_trainer.py", "mnist_mnistClassifier_net.h", "HelperA.h", "CNNTranslator.h", - "mnist_mnistClassifier_calculateClass.h", "CNNTrainer_mnist_mnistClassifier_net.py", + Arrays.asList( + "CNNBufferFile.h", + "CNNNet_mnist_mnistClassifier_net.py", + "mnist_mnistClassifier.cpp", + "mnist_mnistClassifier.h", + "CNNCreator_mnist_mnistClassifier_net.py", + "CNNPredictor_mnist_mnistClassifier_net.h", + "CNNDataLoader_mnist_mnistClassifier_net.py", + "supervised_trainer.py", + "mnist_mnistClassifier_net.h", + "HelperA.h", + "CNNTranslator.h", + "mnist_mnistClassifier_calculateClass.h", + "CNNTrainer_mnist_mnistClassifier_net.py", "mnist_mnistClassifier_net.h")); } @Test public void testHashFunction() { EMADLGenerator tester = new EMADLGenerator(Backend.MXNET); - - try { + + try{ tester.getChecksumForFile("invalid Path!"); assertTrue("Hash method should throw IOException on invalid path", false); - } catch (IOException e) { + } catch(IOException e){ } } - - @Test - public void testAlexNetTagging() { - Log.getFindings().clear(); - String[] args = { "-m", "src/test/resources/models/", "-r", "tagging.Alexnet", "-b", "MXNET", "-f", "n", "-c", - "n" }; - EMADLGeneratorCli.main(args); - assertEquals(Log.getFindings().size(), 1); - assertEquals(Log.getFindings().get(0).toString(), - "Tagfile was found, ignoring data_paths.txt: src/test/resources/models"); - assertTrue(Log.getErrorCount() == 0); - } - - @Test - public void testInvalidPathCoCos() { - Log.getFindings().clear(); - String[] args = { "-m", "src/test/resources/models/", "-r", "tagging.AlexnetInvalid", "-b", "MXNET", "-f", "n", - "-c", "n" }; - EMADLGeneratorCli.main(args); - assertEquals(Log.getFindings().size(), 3); - assertTrue( - Log.getFindings().get(0).toString().matches("Filepath '(.)*/test/resources/models' does not exist!")); - assertEquals(Log.getFindings().get(1).toString(), "DatapathType is incorrect, must be of Type: HDF5 or LMDB"); - assertEquals(Log.getFindings().get(2).toString(), - "Tagfile was found, ignoring data_paths.txt: test/resources/models"); - - assertTrue(Log.getErrorCount() == 0); - } - - @Test - public void testInvalidTypeCocos() { - Log.getFindings().clear(); - String[] args = { "-m", "src/test/resources/models/", "-r", "tagging.AlexnetInvalidType", "-b", "MXNET", "-f", - "n", "-c", "n" }; - EMADLGeneratorCli.main(args); - assertEquals(Log.getFindings().size(), 2); - assertEquals(Log.getFindings().get(0).toString(), "DatapathType is incorrect, must be of Type: HDF5 or LMDB"); - assertEquals(Log.getFindings().get(1).toString(), - "Tagfile was found, ignoring data_paths.txt: src/test/resources/models"); - - assertTrue(Log.getErrorCount() == 0); - } } diff --git a/src/test/resources/models/tagging/AlexNet.tag b/src/test/resources/models/tagging/AlexNet.tag index 9a232fec4de3369d84c973fe960c22425b05424a..d6eccd6a5f4e00a1511ea4917d81fbc3526dbf96 100644 --- a/src/test/resources/models/tagging/AlexNet.tag +++ b/src/test/resources/models/tagging/AlexNet.tag @@ -1,12 +1,7 @@ package tagging; conforms to dltag.DataPathTagSchema; -tags Alexnet { -tag Alexnet with DataPath = {path = src/test/resources/models, type = LMDB}; -tag AlexnetInvalid with DataPath = {path = test/resources/models, type = random}; -tag AlexnetInvalidType with DataPath = {path = src/test/resources/models, type = LMBD}; -tag Parent.a1 with DataPath = {path = instanceA1, type = random}; -} - - - +tags AlexNet { +tag Alexnet with DataPath = {path = dataAlexnetComponent, type = random}; +tag Parent.a1 with DataPath = {path = dataParent, type = random}; +} \ No newline at end of file diff --git a/src/test/resources/models/tagging/Parent.emadl b/src/test/resources/models/tagging/Parent.emadl index 982ae00e9de265ac278d6ec91150173d2037063b..a4fbeb9af529590e74d97f1de06cb5eed4dd282e 100644 --- a/src/test/resources/models/tagging/Parent.emadl +++ b/src/test/resources/models/tagging/Parent.emadl @@ -1,6 +1,5 @@ package tagging; -component Parent{ - instance Alexnet a1; - instance Alexnet a2; +component Parent { + instance tagging.Alexnet a1; } \ No newline at end of file