diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArchTemplateController.java b/src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArchTemplateController.java index 2454b24d86aba5df28b03766ad19a41c9e0d7c47..31bf0463653368f622bb689d69721f7cffff3681 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArchTemplateController.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArchTemplateController.java @@ -127,6 +127,10 @@ public class CNNArchTemplateController { return list; } + public String getComponentName(){ + return getArchitecture().getComponentName(); + } + public void include(String relativePath, String templateWithoutFileEnding, Writer writer){ String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING; Map ftlContext = new HashMap<>(); diff --git a/src/main/resources/templates/mxnet/CNNCreator.ftl b/src/main/resources/templates/mxnet/CNNCreator.ftl index 5e0ead0d70d048207b177e5235467ed2c85d59b5..5a0c600a2b0124ddc844b4dfeb8883511893aa42 100644 --- a/src/main/resources/templates/mxnet/CNNCreator.ftl +++ b/src/main/resources/templates/mxnet/CNNCreator.ftl @@ -19,8 +19,8 @@ class ${tc.fileNameWithoutEnding}: module = None _data_dir_ = "${tc.dataPath}/" - _model_dir_ = "model/${tc.fullArchitectureName}/" - _model_prefix_ = "${tc.architectureName}" + _model_dir_ = "model/${tc.componentName}/" + _model_prefix_ = "model" _input_names_ = [${tc.join(tc.architectureInputs, ",", "'", "'")}] _input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")})] _output_names_ = [${tc.join(tc.architectureOutputs, ",", "'", "_label'")}] diff --git a/src/main/resources/templates/mxnet/CNNPredictor.ftl b/src/main/resources/templates/mxnet/CNNPredictor.ftl index 3fe252d7ab8dc0132b48f1db38a10442d5bbef6c..283ed874833e8a0ed1738a5d1cf256fe59c81208 100644 --- a/src/main/resources/templates/mxnet/CNNPredictor.ftl +++ b/src/main/resources/templates/mxnet/CNNPredictor.ftl @@ -11,8 +11,8 @@ class ${tc.fileNameWithoutEnding}{ public: - const std::string json_file = "model/${tc.fullArchitectureName}/${tc.architectureName}_newest-symbol.json"; - const std::string param_file = "model/${tc.fullArchitectureName}/${tc.architectureName}_newest-0000.params"; + const std::string json_file = "model/${tc.componentName}/model_newest-symbol.json"; + const std::string param_file = "model/${tc.componentName}/model_newest-0000.params"; //const std::vector input_keys = {"data"}; const std::vector input_keys = {${tc.join(tc.architectureInputs, ",", "\"", "\"")}}; const std::vector> input_shapes = {<#list tc.architecture.inputs as input>{1,${tc.join(input.definition.type.dimensions, ",")}}<#if input?has_next>,};