Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
generators
EMADL2CPP
Commits
9e00e9b6
Commit
9e00e9b6
authored
Sep 07, 2018
by
Svetlana Pavlitskaya
Committed by
Evgeny Kusmenko
Sep 07, 2018
Browse files
Using updated CNNArchMXNet, refactoring of code related to CNNTrain
parent
7111056a
Changes
6
Hide whitespace changes
Inline
Side-by-side
pom.xml
View file @
9e00e9b6
...
...
@@ -8,7 +8,7 @@
<groupId>
de.monticore.lang.monticar
</groupId>
<artifactId>
embedded-montiarc-emadl-generator
</artifactId>
<version>
0.2.
3
-SNAPSHOT
</version>
<version>
0.2.
4
-SNAPSHOT
</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
...
...
@@ -16,8 +16,8 @@
<!-- .. SE-Libraries .................................................. -->
<emadl.version>
0.2.2-SNAPSHOT
</emadl.version>
<CNNTrain.version>
0.2.
4
-SNAPSHOT
</CNNTrain.version>
<cnnarch-mxnet-generator.version>
0.2.
4
-SNAPSHOT
</cnnarch-mxnet-generator.version>
<CNNTrain.version>
0.2.
5
-SNAPSHOT
</CNNTrain.version>
<cnnarch-mxnet-generator.version>
0.2.
5
-SNAPSHOT
</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>
0.2.2-SNAPSHOT
</cnnarch-caffe2-generator.version>
<embedded-montiarc-math-generator>
0.0.25-SNAPSHOT
</embedded-montiarc-math-generator>
...
...
src/main/java/de/monticore/lang/monticar/emadl/generator/Backend.java
View file @
9e00e9b6
...
...
@@ -4,24 +4,35 @@ package de.monticore.lang.monticar.emadl.generator;
import
de.monticore.lang.monticar.cnnarch.CNNArchGenerator
;
import
de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNet
;
import
de.monticore.lang.monticar.cnnarch.caffe2generator.CNNArch2Caffe2
;
import
de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet
;
import
de.monticore.lang.monticar.cnntrain.CNNTrainGenerator
;
import
java.util.Optional
;
public
enum
Backend
{
MXNET
{
@Override
public
CNNArchGenerator
getGenerator
()
{
public
CNNArchGenerator
get
CNNArch
Generator
()
{
return
new
CNNArch2MxNet
();
}
@Override
public
CNNTrainGenerator
getCNNTrainGenerator
()
{
return
new
CNNTrain2MxNet
();
}
},
CAFFE2
{
@Override
public
CNNArchGenerator
getGenerator
()
{
public
CNNArchGenerator
get
CNNArch
Generator
()
{
return
new
CNNArch2Caffe2
();
}
@Override
public
CNNTrainGenerator
getCNNTrainGenerator
()
{
return
null
;
}
// not implemented yet
};
public
abstract
CNNArchGenerator
getGenerator
();
public
abstract
CNNArchGenerator
getCNNArchGenerator
();
public
abstract
CNNTrainGenerator
getCNNTrainGenerator
();
public
static
Optional
<
Backend
>
getBackendFromString
(
String
backend
){
switch
(
backend
){
...
...
src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java
View file @
9e00e9b6
...
...
@@ -20,19 +20,16 @@
*/
package
de.monticore.lang.monticar.emadl.generator
;
import
com.google.common.base.Charsets
;
import
com.google.common.base.Joiner
;
import
com.google.common.base.Splitter
;
import
com.google.common.base.Charsets
;
import
com.google.common.io.Resources
;
import
de.monticore.io.paths.ModelPath
;
import
de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ComponentSymbol
;
import
de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ExpandedComponentInstanceSymbol
;
import
de.monticore.lang.math._symboltable.MathStatementsSymbol
;
import
de.monticore.lang.monticar.cnnarch.CNNArchGenerator
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol
;
import
de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos
;
import
de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage
;
import
de.monticore.lang.monticar.cnntrain.CNNTrainGenerator
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.emadl._cocos.EMADLCocos
;
import
de.monticore.lang.monticar.generator.FileContent
;
...
...
@@ -42,13 +39,12 @@ import de.monticore.lang.monticar.generator.cpp.SimulatorIntegrationHelper;
import
de.monticore.lang.monticar.generator.cpp.TypesGeneratorCPP
;
import
de.monticore.lang.monticar.generator.cpp.converter.TypeConverter
;
import
de.monticore.lang.tagging._symboltable.TaggingResolver
;
import
de.monticore.symboltable.GlobalScope
;
import
de.monticore.symboltable.Scope
;
import
de.se_rwth.commons.Splitters
;
import
de.se_rwth.commons.logging.Log
;
import
freemarker.template.TemplateException
;
import
java.io.
*
;
import
java.io.
IOException
;
import
java.nio.charset.Charset
;
import
java.nio.file.Files
;
import
java.nio.file.Path
;
...
...
@@ -60,6 +56,7 @@ public class EMADLGenerator {
private
GeneratorCPP
emamGen
;
private
CNNArchGenerator
cnnArchGenerator
;
private
CNNTrainGenerator
cnnTrainGenerator
;
private
String
modelsPath
;
...
...
@@ -68,7 +65,8 @@ public class EMADLGenerator {
emamGen
=
new
GeneratorCPP
();
emamGen
.
useArmadilloBackend
();
emamGen
.
setGenerationTargetPath
(
"./target/generated-sources-emadl/"
);
cnnArchGenerator
=
backend
.
getGenerator
();
cnnArchGenerator
=
backend
.
getCNNArchGenerator
();
cnnTrainGenerator
=
backend
.
getCNNTrainGenerator
();
}
public
String
getModelsPath
()
{
...
...
@@ -265,70 +263,50 @@ public class EMADLGenerator {
}
public
List
<
FileContent
>
generateCNNTrainer
(
Set
<
ExpandedComponentInstanceSymbol
>
allInstances
,
String
mainComponentName
)
{
List
<
String
>
cnnInstanceNames
=
new
ArrayList
<>();
List
<
ConfigurationSymbol
>
configurations
=
new
ArrayList
<>();
List
<
FileContent
>
fileContents
=
new
ArrayList
<>();
for
(
ExpandedComponentInstanceSymbol
componentInstance
:
allInstances
)
{
ComponentSymbol
component
=
componentInstance
.
getComponentType
().
getReferencedSymbol
();
Optional
<
ArchitectureSymbol
>
architecture
=
component
.
getSpannedScope
().
resolve
(
""
,
ArchitectureSymbol
.
KIND
);
if
(
architecture
.
isPresent
())
{
ConfigurationSymbol
configuration
=
getTrainingConfiguration
(
mainComponentName
,
component
,
componentInstance
);
configurations
.
add
(
configuration
);
cnnInstanceNames
.
add
(
componentInstance
.
getFullName
().
replaceAll
(
"\\."
,
"_"
));
String
trainConfigFilename
;
String
mainComponentConfigFilename
=
mainComponentName
.
replaceAll
(
"\\."
,
"/"
);
String
componentConfigFilename
=
component
.
getFullName
().
replaceAll
(
"\\."
,
"/"
);
String
instanceConfigFilename
=
component
.
getFullName
().
replaceAll
(
"\\."
,
"/"
)
+
"_"
+
component
.
getName
();
if
(
Files
.
exists
(
Paths
.
get
(
getModelsPath
()
+
instanceConfigFilename
+
".cnnt"
)))
{
trainConfigFilename
=
instanceConfigFilename
;
}
else
if
(
Files
.
exists
(
Paths
.
get
(
getModelsPath
()
+
componentConfigFilename
+
".cnnt"
))){
trainConfigFilename
=
componentConfigFilename
;
}
else
if
(
Files
.
exists
(
Paths
.
get
(
getModelsPath
()
+
mainComponentConfigFilename
+
".cnnt"
))){
trainConfigFilename
=
mainComponentConfigFilename
;
}
else
{
Log
.
error
(
"Missing configuration file. "
+
"Could not find a file with any of the following names (only one needed): '"
+
getModelsPath
()
+
instanceConfigFilename
+
".cnnt', '"
+
getModelsPath
()
+
componentConfigFilename
+
".cnnt', '"
+
getModelsPath
()
+
mainComponentConfigFilename
+
".cnnt'."
+
" These files denote respectively the configuration for the single instance, the component or the whole system."
);
return
null
;
}
//should be removed when CNNTrain supports packages
List
<
String
>
names
=
Splitter
.
on
(
"/"
).
splitToList
(
trainConfigFilename
);
trainConfigFilename
=
names
.
get
(
names
.
size
()-
1
);
Path
modelPath
=
Paths
.
get
(
getModelsPath
()
+
Joiner
.
on
(
"/"
).
join
(
names
.
subList
(
0
,
names
.
size
()-
1
)));
ConfigurationSymbol
configuration
=
cnnTrainGenerator
.
getConfigurationSymbol
(
modelPath
,
trainConfigFilename
);
cnnTrainGenerator
.
setInstanceName
(
componentInstance
.
getFullName
().
replaceAll
(
"\\."
,
"_"
));
Map
<
String
,
String
>
fileContentMap
=
cnnTrainGenerator
.
generateStrings
(
configuration
);
for
(
String
fileName
:
fileContentMap
.
keySet
()){
fileContents
.
add
(
new
FileContent
(
fileContentMap
.
get
(
fileName
),
fileName
));
}
}
}
List
<
FileContent
>
fileContents
=
new
ArrayList
<>();
Map
<
String
,
String
>
fileContentMap
=
cnnArchGenerator
.
generateTrainer
(
configurations
,
cnnInstanceNames
,
mainComponentName
);
for
(
String
fileName
:
fileContentMap
.
keySet
()){
fileContents
.
add
(
new
FileContent
(
fileContentMap
.
get
(
fileName
),
fileName
));
}
return
fileContents
;
}
public
ConfigurationSymbol
getTrainingConfiguration
(
String
mainComponentName
,
ComponentSymbol
component
,
ExpandedComponentInstanceSymbol
instance
)
{
String
configFilename
;
String
mainComponentConfigFilename
=
mainComponentName
.
replaceAll
(
"\\."
,
"/"
);
String
componentConfigFilename
=
component
.
getFullName
().
replaceAll
(
"\\."
,
"/"
);
String
instanceConfigFilename
=
component
.
getFullName
().
replaceAll
(
"\\."
,
"/"
)
+
"_"
+
instance
.
getName
();
if
(
Files
.
exists
(
Paths
.
get
(
getModelsPath
()
+
instanceConfigFilename
+
".cnnt"
)))
{
configFilename
=
instanceConfigFilename
;
}
else
if
(
Files
.
exists
(
Paths
.
get
(
getModelsPath
()
+
componentConfigFilename
+
".cnnt"
))){
configFilename
=
componentConfigFilename
;
}
else
if
(
Files
.
exists
(
Paths
.
get
(
getModelsPath
()
+
mainComponentConfigFilename
+
".cnnt"
))){
configFilename
=
mainComponentConfigFilename
;
}
else
{
Log
.
error
(
"Missing configuration file. "
+
"Could not find a file with any of the following names (only one needed): '"
+
getModelsPath
()
+
instanceConfigFilename
+
".cnnt', '"
+
getModelsPath
()
+
componentConfigFilename
+
".cnnt', '"
+
getModelsPath
()
+
mainComponentConfigFilename
+
".cnnt'."
+
" These files denote respectively the configuration for the single instance, the component or the whole system."
);
return
null
;
}
//should be removed when CNNTrain supports packages
List
<
String
>
names
=
Splitter
.
on
(
"/"
).
splitToList
(
configFilename
);
configFilename
=
names
.
get
(
names
.
size
()-
1
);
Path
modelPath
=
Paths
.
get
(
getModelsPath
()
+
Joiner
.
on
(
"/"
).
join
(
names
.
subList
(
0
,
names
.
size
()-
1
)));
//
//CNNTrainGenerator cnnTrainGenerator = new CNNTrainGenerator(); //No need of cnnTrainGenerator since cnnArchGenerator can also generateTrainer()
final
ModelPath
mp
=
new
ModelPath
(
modelPath
);
GlobalScope
trainScope
=
new
GlobalScope
(
mp
,
new
CNNTrainLanguage
());
Optional
<
CNNTrainCompilationUnitSymbol
>
compilationUnit
=
trainScope
.
resolve
(
configFilename
,
CNNTrainCompilationUnitSymbol
.
KIND
);
if
(!
compilationUnit
.
isPresent
()){
Log
.
error
(
"CNNTrainCompilationUnitSymbol is empty. Could not resolve configuration "
+
configFilename
);
System
.
exit
(
1
);
}
CNNTrainCocos
.
checkAll
(
compilationUnit
.
get
());
ConfigurationSymbol
configuration
=
compilationUnit
.
get
().
getConfiguration
();
return
configuration
;
}
public
String
readResource
(
final
String
fileName
,
Charset
charset
)
{
try
{
return
Resources
.
toString
(
Resources
.
getResource
(
fileName
),
charset
);
...
...
src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java
View file @
9e00e9b6
...
...
@@ -59,7 +59,7 @@ public class GenerationTest extends AbstractSymtabTest {
"cifar10_cifar10Classifier_net.h"
,
"CNNTranslator.h"
,
"cifar10_cifar10Classifier_calculateClass.h"
,
"CNNTrainer_cifar10_
C
ifar10Classifier.py"
));
"CNNTrainer_cifar10_
c
ifar10Classifier
_net
.py"
));
}
@Test
...
...
src/test/resources/models/InstanceTest/NetworkB
_net1
.cnnt
→
src/test/resources/models/InstanceTest/NetworkB.cnnt
View file @
9e00e9b6
configuration NetworkB
_net1
{
configuration NetworkB{
num_epoch:10
batch_size:64
normalize:true
...
...
src/test/resources/target_code/CNNTrainer_cifar10_
C
ifar10Classifier.py
→
src/test/resources/target_code/CNNTrainer_cifar10_
c
ifar10Classifier
_net
.py
View file @
9e00e9b6
File moved
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment