Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Caffe2
Commits
3cde5fb4
Commit
3cde5fb4
authored
Feb 21, 2019
by
nilsfreyer
Browse files
Merge branch 'master' into oneclick_nn_training
parents
cddf24ac
980d2c19
Pipeline
#106734
passed with stages
in 2 minutes and 57 seconds
Changes
36
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
pom.xml
View file @
3cde5fb4
...
...
@@ -16,7 +16,7 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>
0.2.8
</CNNArch.version>
<CNNTrain.version>
0.2.
5
</CNNTrain.version>
<CNNTrain.version>
0.2.
6
</CNNTrain.version>
<embedded-montiarc-math-generator>
0.1.2-SNAPSHOT
</embedded-montiarc-math-generator>
<!-- .. Libraries .................................................. -->
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/ArchitectureElementData.java
View file @
3cde5fb4
...
...
@@ -20,14 +20,12 @@
*/
package
de.monticore.lang.monticar.cnnarch.caffe2generator
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol
;
import
de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers
;
import
de.se_rwth.commons.logging.Log
;
import
javax.annotation.Nullable
;
import
java.util.Arrays
;
import
java.util.List
;
public
class
ArchitectureElementData
{
...
...
@@ -165,31 +163,21 @@ public class ArchitectureElementData {
}
@Nullable
public
List
<
Integer
>
getPadding
(){
public
Integer
getPadding
(){
return
getPadding
((
LayerSymbol
)
getElement
());
}
@Nullable
public
List
<
Integer
>
getPadding
(
LayerSymbol
layer
){
List
<
Integer
>
kernel
=
layer
.
getIntTupleValue
(
AllPredefinedLayers
.
KERNEL_NAME
).
get
();
List
<
Integer
>
stride
=
layer
.
getIntTupleValue
(
AllPredefinedLayers
.
STRIDE_NAME
).
get
();
ArchTypeSymbol
inputType
=
layer
.
getInputTypes
().
get
(
0
);
ArchTypeSymbol
outputType
=
layer
.
getOutputTypes
().
get
(
0
);
int
heightWithPad
=
kernel
.
get
(
0
)
+
stride
.
get
(
0
)*(
outputType
.
getHeight
()
-
1
);
int
widthWithPad
=
kernel
.
get
(
1
)
+
stride
.
get
(
1
)*(
outputType
.
getWidth
()
-
1
);
int
heightPad
=
Math
.
max
(
0
,
heightWithPad
-
inputType
.
getHeight
());
int
widthPad
=
Math
.
max
(
0
,
widthWithPad
-
inputType
.
getWidth
());
int
topPad
=
(
int
)
Math
.
ceil
(
heightPad
/
2.0
);
int
bottomPad
=
(
int
)
Math
.
floor
(
heightPad
/
2.0
);
int
leftPad
=
(
int
)
Math
.
ceil
(
widthPad
/
2.0
);
int
rightPad
=
(
int
)
Math
.
floor
(
widthPad
/
2.0
);
if
(
topPad
==
0
&&
bottomPad
==
0
&&
leftPad
==
0
&&
rightPad
==
0
){
return
null
;
public
Integer
getPadding
(
LayerSymbol
layer
){
String
padding_type
=
((
LayerSymbol
)
getElement
()).
getStringValue
(
AllPredefinedLayers
.
PADDING_NAME
).
get
();
Integer
pad
=
0
;
if
(
padding_type
.
equals
(
AllPredefinedLayers
.
PADDING_VALID
)){
pad
=
0
;
}
else
if
(
padding_type
.
equals
(
AllPredefinedLayers
.
PADDING_SAME
)){
pad
=
1
;
}
return
Arrays
.
asList
(
0
,
0
,
0
,
0
,
topPad
,
bottomPad
,
leftPad
,
rightPad
)
;
return
pad
;
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNArch2Caffe2.java
View file @
3cde5fb4
...
...
@@ -26,7 +26,6 @@ import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.IOSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage
;
import
de.monticore.lang.monticar.cnnarch.DataPathConfigParser
;
...
...
@@ -53,18 +52,18 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
private
boolean
isSupportedLayer
(
ArchitectureElementSymbol
element
,
LayerSupportChecker
layerChecker
){
List
<
ArchitectureElementSymbol
>
constructLayerElemList
;
if
(!(
element
instanceof
IOSymbol
)
&&
(
element
.
getResolvedThis
().
get
()
instanceof
CompositeElementSymbol
))
{
if
(
element
.
getResolvedThis
().
get
()
instanceof
CompositeElementSymbol
)
{
constructLayerElemList
=
((
CompositeElementSymbol
)
element
.
getResolvedThis
().
get
()).
getElements
();
for
(
ArchitectureElementSymbol
constructedLayerElement
:
constructLayerElemList
)
{
if
(!
isSupportedLayer
(
constructedLayerElement
,
layerChecker
))
return
false
;
if
(!
isSupportedLayer
(
constructedLayerElement
,
layerChecker
))
{
return
false
;
}
}
}
if
(!
layerChecker
.
isSupported
(
element
.
toString
()))
{
Log
.
error
(
"Unsupported layer "
+
"'"
+
element
.
getName
()
+
"'"
+
" for the backend CAFFE2."
);
return
false
;
}
else
{
}
else
{
return
true
;
}
}
...
...
@@ -72,7 +71,9 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
private
boolean
supportCheck
(
ArchitectureSymbol
architecture
){
LayerSupportChecker
layerChecker
=
new
LayerSupportChecker
();
for
(
ArchitectureElementSymbol
element
:
((
CompositeElementSymbol
)
architecture
.
getBody
()).
getElements
()){
if
(!
isSupportedLayer
(
element
,
layerChecker
))
return
false
;
if
(!
isSupportedLayer
(
element
,
layerChecker
))
{
return
false
;
}
}
return
true
;
}
...
...
@@ -84,6 +85,11 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
public
void
setModelPath
(
Path
modelPath
){
this
.
modelPath
=
modelPath
.
toString
();
}
private
static
void
quitGeneration
(){
Log
.
error
(
"Code generation is aborted"
);
System
.
exit
(
1
);
}
public
CNNArch2Caffe2
()
{
setGenerationTargetPath
(
"./target/generated-sources-cnnarch/"
);
...
...
@@ -116,13 +122,12 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
Optional
<
CNNArchCompilationUnitSymbol
>
compilationUnit
=
scope
.
resolve
(
rootModelName
,
CNNArchCompilationUnitSymbol
.
KIND
);
if
(!
compilationUnit
.
isPresent
()){
Log
.
error
(
"could not resolve architecture "
+
rootModelName
);
System
.
exit
(
1
);
quitGeneration
(
);
}
CNNArchCocos
.
checkAll
(
compilationUnit
.
get
());
if
(!
supportCheck
(
compilationUnit
.
get
().
getArchitecture
())){
Log
.
error
(
"Code generation aborted."
);
System
.
exit
(
1
);
quitGeneration
();
}
try
{
...
...
@@ -132,8 +137,7 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
compilationUnit
.
get
().
getArchitecture
().
setDataPath
(
dataPath
);
compilationUnit
.
get
().
getArchitecture
().
setComponentName
(
rootModelName
);
generateFiles
(
compilationUnit
.
get
().
getArchitecture
());
}
catch
(
IOException
e
){
}
catch
(
IOException
e
){
Log
.
error
(
e
.
toString
());
}
}
...
...
@@ -188,7 +192,7 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
try
{
generateFromFilecontentsMap
(
fileContentMap
);
}
catch
(
IOException
e
)
{
e
.
printStackTrac
e
();
Log
.
error
(
"CMake file could not be generated"
+
e
.
getMessag
e
()
)
;
}
}
...
...
@@ -211,8 +215,8 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
cMakeConfig
.
addCMakeCommand
(
"set(LIBS ${LIBS} -lprotobuf -lglog -lgflags)"
);
cMakeConfig
.
addCMakeCommand
(
"find_package(CUDA)"
+
"\n"
+
"set(INCLUDE_DIRS ${INCLUDE_DIRS} ${CUDA_INCLUDE_DIRS})"
+
"\n"
+
"set(LIBS ${LIBS} ${CUDA_LIBRARIES} ${CUDA_curand_LIBRARY})"
+
"\n"
);
//Needed since CUDA cannot be found correctly (including CUDA_curand_LIBRARY) and as optional using CMakeFindModule
+
"set(LIBS ${LIBS} ${CUDA_LIBRARIES} ${CUDA_curand_LIBRARY})"
+
"\n"
);
//Needed since CUDA cannot be found correctly (including CUDA_curand_LIBRARY)
cMakeConfig
.
addCMakeCommand
(
"if(CUDA_FOUND)"
+
"\n"
+
" set(LIBS ${LIBS} caffe2 caffe2_gpu)"
+
"\n"
+
"else()"
+
"\n"
+
" set(LIBS ${LIBS} caffe2)"
+
"\n"
+
"endif()"
);
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNArch2Caffe2Cli.java
View file @
3cde5fb4
...
...
@@ -19,6 +19,7 @@
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch.caffe2generator
;
import
de.se_rwth.commons.logging.Log
;
import
org.apache.commons.cli.*
;
...
...
@@ -73,13 +74,18 @@ public class CNNArch2Caffe2Cli {
try
{
cliArgs
=
parser
.
parse
(
options
,
args
);
}
catch
(
ParseException
e
)
{
System
.
err
.
println
(
"argument parsing exception: "
+
e
.
getMessage
());
System
.
exit
(
1
);
Log
.
error
(
"argument parsing exception: "
+
e
.
getMessage
());
quitGeneration
(
);
return
null
;
}
return
cliArgs
;
}
private
static
void
quitGeneration
(){
Log
.
error
(
"Code generation is aborted"
);
System
.
exit
(
1
);
}
private
static
void
runGenerator
(
CommandLine
cliArgs
)
{
Path
modelsDirPath
=
Paths
.
get
(
cliArgs
.
getOptionValue
(
OPTION_MODELS_PATH
.
getOpt
()));
String
rootModelName
=
cliArgs
.
getOptionValue
(
OPTION_ROOT_MODEL
.
getOpt
());
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNArchTemplateController.java
View file @
3cde5fb4
...
...
@@ -37,6 +37,7 @@ public class CNNArchTemplateController {
private
LayerNameCreator
nameManager
;
private
ArchitectureSymbol
architecture
;
private
String
loss
;
//temporary attributes. They are set after calling process()
private
Writer
writer
;
...
...
@@ -44,6 +45,8 @@ public class CNNArchTemplateController {
private
Target
targetLanguage
;
private
ArchitectureElementData
dataElement
;
public
static
final
String
CROSS_ENTROPY
=
"cross_entropy"
;
public
static
final
String
EUCLIDEAN
=
"euclidean"
;
public
CNNArchTemplateController
(
ArchitectureSymbol
architecture
)
{
setArchitecture
(
architecture
);
...
...
@@ -96,8 +99,7 @@ public class CNNArchTemplateController {
if
(
isSoftmaxOutput
(
layer
)
||
isLogisticRegressionOutput
(
layer
)){
inputNames
=
getLayerInputs
(
layer
.
getInputElement
().
get
());
}
else
{
}
else
{
for
(
ArchitectureElementSymbol
input
:
layer
.
getPrevious
())
{
if
(
input
.
getOutputTypes
().
size
()
==
1
)
{
inputNames
.
add
(
getName
(
input
));
...
...
@@ -132,6 +134,9 @@ public class CNNArchTemplateController {
return
getArchitecture
().
getComponentName
();
}
public
String
getArchitectureLoss
(){
return
this
.
loss
;
}
public
void
include
(
String
relativePath
,
String
templateWithoutFileEnding
,
Writer
writer
){
String
templatePath
=
relativePath
+
templateWithoutFileEnding
+
FTL_FILE_ENDING
;
...
...
@@ -148,12 +153,10 @@ public class CNNArchTemplateController {
if
(
ioElement
.
isAtomic
()){
if
(
ioElement
.
isInput
()){
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
"Input"
,
writer
);
}
else
{
}
else
{
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
"Output"
,
writer
);
}
}
else
{
}
else
{
include
(
ioElement
.
getResolvedThis
().
get
(),
writer
);
}
...
...
@@ -170,8 +173,7 @@ public class CNNArchTemplateController {
String
templateName
=
layer
.
getDeclaration
().
getName
();
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
templateName
,
writer
);
}
}
else
{
}
else
{
include
(
layer
.
getResolvedThis
().
get
(),
writer
);
}
...
...
@@ -192,11 +194,9 @@ public class CNNArchTemplateController {
public
void
include
(
ArchitectureElementSymbol
architectureElement
,
Writer
writer
){
if
(
architectureElement
instanceof
CompositeElementSymbol
){
include
((
CompositeElementSymbol
)
architectureElement
,
writer
);
}
else
if
(
architectureElement
instanceof
LayerSymbol
){
}
else
if
(
architectureElement
instanceof
LayerSymbol
){
include
((
LayerSymbol
)
architectureElement
,
writer
);
}
else
{
}
else
{
include
((
IOSymbol
)
architectureElement
,
writer
);
}
}
...
...
@@ -209,15 +209,15 @@ public class CNNArchTemplateController {
}
public
Map
.
Entry
<
String
,
String
>
process
(
String
templateNameWithoutEnding
,
Target
targetLanguage
){
StringWriter
w
riter
=
new
StringWriter
();
StringWriter
newW
riter
=
new
StringWriter
();
this
.
mainTemplateNameWithoutEnding
=
templateNameWithoutEnding
;
this
.
targetLanguage
=
targetLanguage
;
this
.
writer
=
w
riter
;
this
.
writer
=
newW
riter
;
include
(
""
,
templateNameWithoutEnding
,
w
riter
);
include
(
""
,
templateNameWithoutEnding
,
newW
riter
);
String
fileEnding
=
targetLanguage
.
toString
();
String
fileName
=
getFileNameWithoutEnding
()
+
fileEnding
;
Map
.
Entry
<
String
,
String
>
fileContent
=
new
AbstractMap
.
SimpleEntry
<>(
fileName
,
w
riter
.
toString
());
Map
.
Entry
<
String
,
String
>
fileContent
=
new
AbstractMap
.
SimpleEntry
<>(
fileName
,
newW
riter
.
toString
());
this
.
mainTemplateNameWithoutEnding
=
null
;
this
.
targetLanguage
=
null
;
...
...
@@ -246,27 +246,39 @@ public class CNNArchTemplateController {
public
boolean
isLogisticRegressionOutput
(
ArchitectureElementSymbol
architectureElement
){
return
isTOutput
(
Sigmoid
.
class
,
architectureElement
);
if
(
isTOutput
(
Sigmoid
.
class
,
architectureElement
)){
this
.
loss
=
CROSS_ENTROPY
;
return
true
;
}
return
false
;
}
public
boolean
isLinearRegressionOutput
(
ArchitectureElementSymbol
architectureElement
){
return
architectureElement
.
isOutput
()
if
(
architectureElement
.
isOutput
()
&&
!
isLogisticRegressionOutput
(
architectureElement
)
&&
!
isSoftmaxOutput
(
architectureElement
);
&&
!
isSoftmaxOutput
(
architectureElement
)){
this
.
loss
=
EUCLIDEAN
;
return
true
;
}
return
false
;
}
public
boolean
isSoftmaxOutput
(
ArchitectureElementSymbol
architectureElement
){
return
isTOutput
(
Softmax
.
class
,
architectureElement
);
if
(
isTOutput
(
Softmax
.
class
,
architectureElement
)){
this
.
loss
=
CROSS_ENTROPY
;
return
true
;
}
return
false
;
}
private
boolean
isTOutput
(
Class
inputPredefinedLayerClass
,
ArchitectureElementSymbol
architectureElement
){
if
(
architectureElement
.
isOutput
()
){
if
(
architectureElement
.
getInputElement
().
isPresent
()
&&
architectureElement
.
getInputElement
().
get
()
instanceof
LayerSymbol
){
LayerSymbol
inputLayer
=
(
LayerSymbol
)
architectureElement
.
getInputElement
().
get
()
;
if
(
inputPredefinedLayerClass
.
isInstance
(
inputLayer
.
getDeclaration
())){
return
true
;
}
if
(
architectureElement
.
isOutput
()
&&
architectureElement
.
getInputElement
().
isPresent
()
&&
architectureElement
.
getInputElement
().
get
()
instanceof
LayerSymbol
){
LayerSymbol
inputLayer
=
(
LayerSymbol
)
architectureElement
.
getInputElement
().
get
();
if
(
inputPredefinedLayerClass
.
isInstance
(
inputLayer
.
getDeclaration
())){
return
true
;
}
}
return
false
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNTrain2Caffe2.java
View file @
3cde5fb4
...
...
@@ -8,6 +8,7 @@ import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import
de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.OptimizerSymbol
;
import
de.monticore.lang.monticar.generator.FileContent
;
import
de.monticore.lang.monticar.generator.cpp.GeneratorCPP
;
import
de.monticore.symboltable.GlobalScope
;
...
...
@@ -37,7 +38,9 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
it
=
configuration
.
getEntryMap
().
keySet
().
iterator
();
while
(
it
.
hasNext
())
{
String
key
=
it
.
next
().
toString
();
if
(
funcChecker
.
getUnsupportedElemList
().
contains
(
key
))
it
.
remove
();
if
(
funcChecker
.
getUnsupportedElemList
().
contains
(
key
))
{
it
.
remove
();
}
}
}
...
...
@@ -47,17 +50,25 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
ASTOptimizerEntry
astOptimizer
=
(
ASTOptimizerEntry
)
configuration
.
getOptimizer
().
getAstNode
().
get
();
astOptimizer
.
accept
(
funcChecker
);
if
(
funcChecker
.
getUnsupportedElemList
().
contains
(
funcChecker
.
unsupportedOptFlag
))
{
configuration
.
setOptimizer
(
null
);
OptimizerSymbol
adamOptimizer
=
new
OptimizerSymbol
(
"adam"
);
configuration
.
setOptimizer
(
adamOptimizer
);
//Set default as adam optimizer
}
else
{
Iterator
it
=
configuration
.
getOptimizer
().
getOptimizerParamMap
().
keySet
().
iterator
();
while
(
it
.
hasNext
())
{
String
key
=
it
.
next
().
toString
();
if
(
funcChecker
.
getUnsupportedElemList
().
contains
(
key
))
it
.
remove
();
if
(
funcChecker
.
getUnsupportedElemList
().
contains
(
key
))
{
it
.
remove
();
}
}
}
}
}
private
static
void
quitGeneration
(){
Log
.
error
(
"Code generation is aborted"
);
System
.
exit
(
1
);
}
public
CNNTrain2Caffe2
()
{
setGenerationTargetPath
(
"./target/generated-sources-cnnarch/"
);
}
...
...
@@ -89,7 +100,7 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
Optional
<
CNNTrainCompilationUnitSymbol
>
compilationUnit
=
scope
.
resolve
(
rootModelName
,
CNNTrainCompilationUnitSymbol
.
KIND
);
if
(!
compilationUnit
.
isPresent
())
{
Log
.
error
(
"could not resolve training configuration "
+
rootModelName
);
System
.
exit
(
1
);
quitGeneration
(
);
}
setInstanceName
(
compilationUnit
.
get
().
getFullName
());
CNNTrainCocos
.
checkAll
(
compilationUnit
.
get
());
...
...
@@ -107,7 +118,7 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
genCPP
.
generateFile
(
new
FileContent
(
fileContents
.
get
(
fileName
),
fileName
));
}
}
catch
(
IOException
e
)
{
e
.
printStackTrac
e
();
Log
.
error
(
"CNNTrainer file could not be generated"
+
e
.
getMessag
e
()
)
;
}
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/ConfigurationData.java
View file @
3cde5fb4
...
...
@@ -67,6 +67,13 @@ public class ConfigurationData {
return
getConfiguration
().
getEntry
(
"eval_metric"
).
getValue
().
toString
();
}
public
String
getLoss
()
{
if
(!
getConfiguration
().
getEntryMap
().
containsKey
(
"loss"
))
{
return
null
;
}
return
getConfiguration
().
getEntry
(
"loss"
).
getValue
().
toString
();
}
public
String
getOptimizerName
()
{
if
(
getConfiguration
().
getOptimizer
()
==
null
)
{
return
null
;
...
...
@@ -89,8 +96,7 @@ public class ConfigurationData {
Class
realClass
=
entry
.
getValue
().
getValue
().
getValue
().
getClass
();
if
(
realClass
==
Boolean
.
class
)
{
valueAsString
=
(
Boolean
)
entry
.
getValue
().
getValue
().
getValue
()
?
"True"
:
"False"
;
}
else
if
(
lrPolicyClasses
.
contains
(
realClass
))
{
}
else
if
(
lrPolicyClasses
.
contains
(
realClass
))
{
valueAsString
=
"'"
+
valueAsString
+
"'"
;
}
mapToStrings
.
put
(
paramName
,
valueAsString
);
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/LayerNameCreator.java
View file @
3cde5fb4
...
...
@@ -47,17 +47,14 @@ public class LayerNameCreator {
protected
int
name
(
ArchitectureElementSymbol
architectureElement
,
int
stage
,
List
<
Integer
>
streamIndices
){
if
(
architectureElement
instanceof
CompositeElementSymbol
){
return
nameComposite
((
CompositeElementSymbol
)
architectureElement
,
stage
,
streamIndices
);
}
else
{
}
else
{
if
(
architectureElement
.
isAtomic
()){
if
(
architectureElement
.
getMaxSerialLength
().
get
()
>
0
){
return
add
(
architectureElement
,
stage
,
streamIndices
);
}
else
{
}
else
{
return
stage
;
}
}
else
{
}
else
{
ArchitectureElementSymbol
resolvedElement
=
architectureElement
.
getResolvedThis
().
get
();
return
name
(
resolvedElement
,
stage
,
streamIndices
);
}
...
...
@@ -78,8 +75,7 @@ public class LayerNameCreator {
streamIndices
.
remove
(
lastIndex
);
return
Collections
.
max
(
endStages
)
+
1
;
}
else
{
}
else
{
int
endStage
=
stage
;
for
(
ArchitectureElementSymbol
subElement
:
compositeElement
.
getElements
()){
endStage
=
name
(
subElement
,
endStage
,
streamIndices
);
...
...
@@ -113,8 +109,7 @@ public class LayerNameCreator {
name
=
name
+
"_"
+
arrayAccess
+
"_"
;
}
return
name
;
}
else
{
}
else
{
return
createBaseName
(
architectureElement
)
+
stage
+
createStreamPostfix
(
streamIndices
)
+
"_"
;
}
}
...
...
@@ -132,11 +127,9 @@ public class LayerNameCreator {
}
else
{
return
layerDeclaration
.
getName
().
toLowerCase
();
}
}
else
if
(
architectureElement
instanceof
CompositeElementSymbol
){
}
else
if
(
architectureElement
instanceof
CompositeElementSymbol
){
return
"group"
;
}
else
{
}
else
{
return
architectureElement
.
getName
();
}
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/TemplateConfiguration.java
View file @
3cde5fb4
...
...
@@ -43,6 +43,11 @@ public class TemplateConfiguration {
configuration
.
setTemplateExceptionHandler
(
TemplateExceptionHandler
.
RETHROW_HANDLER
);
}
private
static
void
quitGeneration
(){
Log
.
error
(
"Code generation is aborted"
);
System
.
exit
(
1
);
}
public
Configuration
getConfiguration
()
{
return
configuration
;
}
...
...
@@ -58,14 +63,12 @@ public class TemplateConfiguration {
try
{
Template
template
=
TemplateConfiguration
.
get
().
getTemplate
(
templatePath
);
template
.
process
(
ftlContext
,
writer
);
}
catch
(
IOException
e
)
{
}
catch
(
IOException
e
)
{
Log
.
error
(
"Freemarker could not find template "
+
templatePath
+
" :\n"
+
e
.
getMessage
());
System
.
exit
(
1
);
}
catch
(
TemplateException
e
){
quitGeneration
();
}
catch
(
TemplateException
e
){
Log
.
error
(
"An exception occured in template "
+
templatePath
+
" :\n"
+
e
.
getMessage
());
System
.
exit
(
1
);
quitGeneration
(
);
}
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/TrainParamSupportChecker.java
View file @
3cde5fb4
...
...
@@ -25,12 +25,14 @@ public class TrainParamSupportChecker implements CNNTrainVisitor {
public
TrainParamSupportChecker
()
{
}
public
String
unsupportedOptFlag
=
"unsupported_optimizer"
;
public
static
final
String
unsupportedOptFlag
=
"unsupported_optimizer"
;
public
List
getUnsupportedElemList
(){
return
this
.
unsupportedElemList
;
}
//Empty visit method denotes that the corresponding training parameter is supported.
//To set a training parameter as unsupported, add the corresponding node to the unsupportedElemList
public
void
visit
(
ASTNumEpochEntry
node
){}
public
void
visit
(
ASTBatchSizeEntry
node
){}
...
...
@@ -76,10 +78,7 @@ public class TrainParamSupportChecker implements CNNTrainVisitor {
public
void
visit
(
ASTWeightDecayEntry
node
){}
public
void
visit
(
ASTLRDecayEntry
node
){
printUnsupportedOptimizerParam
(
node
.
getName
());
this
.
unsupportedElemList
.
add
(
node
.
getName
());
}
public
void
visit
(
ASTLRDecayEntry
node
){}
public
void
visit
(
ASTLRPolicyEntry
node
){}
...
...
src/main/resources/templates/caffe2/CNNCreator.ftl
View file @
3cde5fb4
...
...
@@ -3,10 +3,12 @@ from caffe2.python.predictor import mobile_exporter
from caffe2.proto import caffe2_pb2
import numpy as np
import math
import datetime
import logging
import os
import sys
import lmdb
class ${tc.fileNameWithoutEnding}:
module = None
...
...
@@ -27,6 +29,15 @@ class ${tc.fileNameWithoutEnding}:
return iterations_int
def get_epoch_as_iter(self, num_epoch, batch_size, dataset_size): #To print metric durint training process
#Force floating point calculation