Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2MXNet
Commits
1630781c
Commit
1630781c
authored
Jan 31, 2019
by
nilsfreyer
Browse files
adapted Tests
parent
67c5ec06
Pipeline
#101628
failed with stages
in 27 seconds
Changes
10
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArch2MxNet.java
View file @
1630781c
...
...
@@ -26,6 +26,7 @@ import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage
;
import
de.monticore.lang.monticar.cnnarch.DataPathConfigParser
;
import
de.monticore.lang.monticar.generator.FileContent
;
import
de.monticore.lang.monticar.generator.cmake.CMakeConfig
;
import
de.monticore.lang.monticar.generator.cmake.CMakeFindModule
;
...
...
@@ -43,6 +44,7 @@ import java.util.Optional;
public
class
CNNArch2MxNet
implements
CNNArchGenerator
{
private
String
generationTargetPath
;
private
String
modelPath
;
public
CNNArch2MxNet
()
{
setGenerationTargetPath
(
"./target/generated-sources-cnnarch/"
);
...
...
@@ -53,6 +55,14 @@ public class CNNArch2MxNet implements CNNArchGenerator {
return
true
;
}
public
String
getModelPath
(){
return
modelPath
;
}
public
void
setModelPath
(
Path
modelPath
){
this
.
modelPath
=
modelPath
.
toString
();
}
public
String
getGenerationTargetPath
()
{
if
(
generationTargetPath
.
charAt
(
generationTargetPath
.
length
()
-
1
)
!=
'/'
)
{
this
.
generationTargetPath
=
generationTargetPath
+
"/"
;
...
...
@@ -67,6 +77,7 @@ public class CNNArch2MxNet implements CNNArchGenerator {
public
void
generate
(
Path
modelsDirPath
,
String
rootModelName
){
final
ModelPath
mp
=
new
ModelPath
(
modelsDirPath
);
GlobalScope
scope
=
new
GlobalScope
(
mp
,
new
CNNArchLanguage
());
setModelPath
(
modelsDirPath
);
generate
(
scope
,
rootModelName
);
}
...
...
@@ -80,7 +91,12 @@ public class CNNArch2MxNet implements CNNArchGenerator {
CNNArchCocos
.
checkAll
(
compilationUnit
.
get
());
try
{
compilationUnit
.
get
().
getArchitecture
().
setDataPath
(
"Temporary - read the correct data path from the config!"
);
String
confPath
=
getModelPath
()
+
"/data_paths.txt"
;
System
.
out
.
println
(
confPath
);
String
dataPath
=
DataPathConfigParser
.
getDataPath
(
confPath
,
rootModelName
);
System
.
out
.
println
(
dataPath
);
compilationUnit
.
get
().
getArchitecture
().
setDataPath
(
dataPath
);
compilationUnit
.
get
().
getArchitecture
().
setComponentName
(
rootModelName
);
generateFiles
(
compilationUnit
.
get
().
getArchitecture
());
}
catch
(
IOException
e
){
...
...
src/test/resources/architectures/data_paths.txt
0 → 100644
View file @
1630781c
Alexnet data/Alexnet
CifarClassifierNetwork data/CifarClassifierNetwork
VGG16 data/VGG16
\ No newline at end of file
src/test/resources/invalid_tests/data_paths.txt
0 → 100644
View file @
1630781c
Alexnet data/Alexnet
CifarClassifierNetwork data/CifarClassifierNetwork
VGG16 data/VGG16
\ No newline at end of file
src/test/resources/target_code/CNNCreator_Alexnet.py
View file @
1630781c
...
...
@@ -20,7 +20,7 @@ class CNNCreator_Alexnet:
module
=
None
_data_dir_
=
"data/Alexnet/"
_model_dir_
=
"model/Alexnet/"
_model_prefix_
=
"
Alexnet
"
_model_prefix_
=
"
model
"
_input_names_
=
[
'data'
]
_input_shapes_
=
[(
3
,
224
,
224
)]
_output_names_
=
[
'predictions_label'
]
...
...
src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py
View file @
1630781c
...
...
@@ -20,7 +20,7 @@ class CNNCreator_CifarClassifierNetwork:
module
=
None
_data_dir_
=
"data/CifarClassifierNetwork/"
_model_dir_
=
"model/CifarClassifierNetwork/"
_model_prefix_
=
"
CifarClassifierNetwork
"
_model_prefix_
=
"
model
"
_input_names_
=
[
'data'
]
_input_shapes_
=
[(
3
,
32
,
32
)]
_output_names_
=
[
'softmax_label'
]
...
...
src/test/resources/target_code/CNNCreator_VGG16.py
View file @
1630781c
...
...
@@ -20,7 +20,7 @@ class CNNCreator_VGG16:
module
=
None
_data_dir_
=
"data/VGG16/"
_model_dir_
=
"model/VGG16/"
_model_prefix_
=
"
VGG16
"
_model_prefix_
=
"
model
"
_input_names_
=
[
'data'
]
_input_shapes_
=
[(
3
,
224
,
224
)]
_output_names_
=
[
'predictions_label'
]
...
...
src/test/resources/target_code/CNNPredictor_Alexnet.h
View file @
1630781c
...
...
@@ -11,8 +11,8 @@
class
CNNPredictor_Alexnet
{
public:
const
std
::
string
json_file
=
"model/Alexnet/
Alexnet
_newest-symbol.json"
;
const
std
::
string
param_file
=
"model/Alexnet/
Alexnet
_newest-0000.params"
;
const
std
::
string
json_file
=
"model/Alexnet/
model
_newest-symbol.json"
;
const
std
::
string
param_file
=
"model/Alexnet/
model
_newest-0000.params"
;
//const std::vector<std::string> input_keys = {"data"};
const
std
::
vector
<
std
::
string
>
input_keys
=
{
"data"
};
const
std
::
vector
<
std
::
vector
<
mx_uint
>>
input_shapes
=
{{
1
,
3
,
224
,
224
}};
...
...
src/test/resources/target_code/CNNPredictor_CifarClassifierNetwork.h
View file @
1630781c
...
...
@@ -11,8 +11,8 @@
class
CNNPredictor_CifarClassifierNetwork
{
public:
const
std
::
string
json_file
=
"model/CifarClassifierNetwork/
CifarClassifierNetwork
_newest-symbol.json"
;
const
std
::
string
param_file
=
"model/CifarClassifierNetwork/
CifarClassifierNetwork
_newest-0000.params"
;
const
std
::
string
json_file
=
"model/CifarClassifierNetwork/
model
_newest-symbol.json"
;
const
std
::
string
param_file
=
"model/CifarClassifierNetwork/
model
_newest-0000.params"
;
//const std::vector<std::string> input_keys = {"data"};
const
std
::
vector
<
std
::
string
>
input_keys
=
{
"data"
};
const
std
::
vector
<
std
::
vector
<
mx_uint
>>
input_shapes
=
{{
1
,
3
,
32
,
32
}};
...
...
src/test/resources/target_code/CNNPredictor_VGG16.h
View file @
1630781c
...
...
@@ -11,8 +11,8 @@
class
CNNPredictor_VGG16
{
public:
const
std
::
string
json_file
=
"model/VGG16/
VGG16
_newest-symbol.json"
;
const
std
::
string
param_file
=
"model/VGG16/
VGG16
_newest-0000.params"
;
const
std
::
string
json_file
=
"model/VGG16/
model
_newest-symbol.json"
;
const
std
::
string
param_file
=
"model/VGG16/
model
_newest-0000.params"
;
//const std::vector<std::string> input_keys = {"data"};
const
std
::
vector
<
std
::
string
>
input_keys
=
{
"data"
};
const
std
::
vector
<
std
::
vector
<
mx_uint
>>
input_shapes
=
{{
1
,
3
,
224
,
224
}};
...
...
src/test/resources/valid_tests/data_paths.txt
0 → 100644
View file @
1630781c
Alexnet data/Alexnet
CifarClassifierNetwork data/CifarClassifierNetwork
VGG16 data/VGG16
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment