Commit 1630781c authored by nilsfreyer's avatar nilsfreyer

adapted Tests

parent 67c5ec06
Pipeline #101628 failed with stages
in 27 seconds
......@@ -26,6 +26,7 @@ import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.lang.monticar.cnnarch.DataPathConfigParser;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cmake.CMakeConfig;
import de.monticore.lang.monticar.generator.cmake.CMakeFindModule;
......@@ -43,6 +44,7 @@ import java.util.Optional;
public class CNNArch2MxNet implements CNNArchGenerator {
private String generationTargetPath;
private String modelPath;
public CNNArch2MxNet() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
......@@ -53,6 +55,14 @@ public class CNNArch2MxNet implements CNNArchGenerator {
return true;
}
public String getModelPath(){
return modelPath;
}
public void setModelPath(Path modelPath){
this.modelPath = modelPath.toString();
}
public String getGenerationTargetPath() {
if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') {
this.generationTargetPath = generationTargetPath + "/";
......@@ -67,6 +77,7 @@ public class CNNArch2MxNet implements CNNArchGenerator {
public void generate(Path modelsDirPath, String rootModelName){
final ModelPath mp = new ModelPath(modelsDirPath);
GlobalScope scope = new GlobalScope(mp, new CNNArchLanguage());
setModelPath(modelsDirPath);
generate(scope, rootModelName);
}
......@@ -80,7 +91,12 @@ public class CNNArch2MxNet implements CNNArchGenerator {
CNNArchCocos.checkAll(compilationUnit.get());
try{
compilationUnit.get().getArchitecture().setDataPath("Temporary - read the correct data path from the config!");
String confPath = getModelPath() + "/data_paths.txt";
System.out.println(confPath);
String dataPath = DataPathConfigParser.getDataPath(confPath , rootModelName);
System.out.println(dataPath);
compilationUnit.get().getArchitecture().setDataPath(dataPath);
compilationUnit.get().getArchitecture().setComponentName(rootModelName);
generateFiles(compilationUnit.get().getArchitecture());
}
catch (IOException e){
......
Alexnet data/Alexnet
CifarClassifierNetwork data/CifarClassifierNetwork
VGG16 data/VGG16
\ No newline at end of file
Alexnet data/Alexnet
CifarClassifierNetwork data/CifarClassifierNetwork
VGG16 data/VGG16
\ No newline at end of file
......@@ -20,7 +20,7 @@ class CNNCreator_Alexnet:
module = None
_data_dir_ = "data/Alexnet/"
_model_dir_ = "model/Alexnet/"
_model_prefix_ = "Alexnet"
_model_prefix_ = "model"
_input_names_ = ['data']
_input_shapes_ = [(3,224,224)]
_output_names_ = ['predictions_label']
......
......@@ -20,7 +20,7 @@ class CNNCreator_CifarClassifierNetwork:
module = None
_data_dir_ = "data/CifarClassifierNetwork/"
_model_dir_ = "model/CifarClassifierNetwork/"
_model_prefix_ = "CifarClassifierNetwork"
_model_prefix_ = "model"
_input_names_ = ['data']
_input_shapes_ = [(3,32,32)]
_output_names_ = ['softmax_label']
......
......@@ -20,7 +20,7 @@ class CNNCreator_VGG16:
module = None
_data_dir_ = "data/VGG16/"
_model_dir_ = "model/VGG16/"
_model_prefix_ = "VGG16"
_model_prefix_ = "model"
_input_names_ = ['data']
_input_shapes_ = [(3,224,224)]
_output_names_ = ['predictions_label']
......
......@@ -11,8 +11,8 @@
class CNNPredictor_Alexnet{
public:
const std::string json_file = "model/Alexnet/Alexnet_newest-symbol.json";
const std::string param_file = "model/Alexnet/Alexnet_newest-0000.params";
const std::string json_file = "model/Alexnet/model_newest-symbol.json";
const std::string param_file = "model/Alexnet/model_newest-0000.params";
//const std::vector<std::string> input_keys = {"data"};
const std::vector<std::string> input_keys = {"data"};
const std::vector<std::vector<mx_uint>> input_shapes = {{1,3,224,224}};
......
......@@ -11,8 +11,8 @@
class CNNPredictor_CifarClassifierNetwork{
public:
const std::string json_file = "model/CifarClassifierNetwork/CifarClassifierNetwork_newest-symbol.json";
const std::string param_file = "model/CifarClassifierNetwork/CifarClassifierNetwork_newest-0000.params";
const std::string json_file = "model/CifarClassifierNetwork/model_newest-symbol.json";
const std::string param_file = "model/CifarClassifierNetwork/model_newest-0000.params";
//const std::vector<std::string> input_keys = {"data"};
const std::vector<std::string> input_keys = {"data"};
const std::vector<std::vector<mx_uint>> input_shapes = {{1,3,32,32}};
......
......@@ -11,8 +11,8 @@
class CNNPredictor_VGG16{
public:
const std::string json_file = "model/VGG16/VGG16_newest-symbol.json";
const std::string param_file = "model/VGG16/VGG16_newest-0000.params";
const std::string json_file = "model/VGG16/model_newest-symbol.json";
const std::string param_file = "model/VGG16/model_newest-0000.params";
//const std::vector<std::string> input_keys = {"data"};
const std::vector<std::string> input_keys = {"data"};
const std::vector<std::vector<mx_uint>> input_shapes = {{1,3,224,224}};
......
Alexnet data/Alexnet
CifarClassifierNetwork data/CifarClassifierNetwork
VGG16 data/VGG16
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment