Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
What's new
7
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Caffe2
Commits
8dbe6ce3
Commit
8dbe6ce3
authored
Feb 18, 2019
by
Evgeny Kusmenko
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'code_review_changes' into 'master'
Code review changes See merge request
!23
parents
0cd223e5
b4056077
Pipeline
#105792
passed with stages
in 7 minutes and 33 seconds
Changes
9
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
77 additions
and
69 deletions
+77
-69
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/ArchitectureElementData.java
...icar/cnnarch/caffe2generator/ArchitectureElementData.java
+1
-4
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNArch2Caffe2.java
...lang/monticar/cnnarch/caffe2generator/CNNArch2Caffe2.java
+19
-14
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNArch2Caffe2Cli.java
...g/monticar/cnnarch/caffe2generator/CNNArch2Caffe2Cli.java
+8
-2
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNArchTemplateController.java
...ar/cnnarch/caffe2generator/CNNArchTemplateController.java
+16
-22
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNTrain2Caffe2.java
...ang/monticar/cnnarch/caffe2generator/CNNTrain2Caffe2.java
+13
-4
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/ConfigurationData.java
...g/monticar/cnnarch/caffe2generator/ConfigurationData.java
+1
-2
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/LayerNameCreator.java
...ng/monticar/cnnarch/caffe2generator/LayerNameCreator.java
+7
-14
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/TemplateConfiguration.java
...nticar/cnnarch/caffe2generator/TemplateConfiguration.java
+9
-6
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/TrainParamSupportChecker.java
...car/cnnarch/caffe2generator/TrainParamSupportChecker.java
+3
-1
No files found.
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/ArchitectureElementData.java
View file @
8dbe6ce3
...
...
@@ -20,14 +20,12 @@
*/
package
de.monticore.lang.monticar.cnnarch.caffe2generator
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol
;
import
de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers
;
import
de.se_rwth.commons.logging.Log
;
import
javax.annotation.Nullable
;
import
java.util.Arrays
;
import
java.util.List
;
public
class
ArchitectureElementData
{
...
...
@@ -176,8 +174,7 @@ public class ArchitectureElementData {
if
(
padding_type
.
equals
(
AllPredefinedLayers
.
PADDING_VALID
)){
pad
=
0
;
}
else
if
(
padding_type
.
equals
(
AllPredefinedLayers
.
PADDING_SAME
)){
}
else
if
(
padding_type
.
equals
(
AllPredefinedLayers
.
PADDING_SAME
)){
pad
=
1
;
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNArch2Caffe2.java
View file @
8dbe6ce3
...
...
@@ -51,18 +51,18 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
private
boolean
isSupportedLayer
(
ArchitectureElementSymbol
element
,
LayerSupportChecker
layerChecker
){
List
<
ArchitectureElementSymbol
>
constructLayerElemList
;
if
(!(
element
instanceof
IOSymbol
)
&&
(
element
.
getResolvedThis
().
get
()
instanceof
CompositeElementSymbol
))
{
if
(!(
element
instanceof
IOSymbol
)
&&
(
element
.
getResolvedThis
().
get
()
instanceof
CompositeElementSymbol
))
{
constructLayerElemList
=
((
CompositeElementSymbol
)
element
.
getResolvedThis
().
get
()).
getElements
();
for
(
ArchitectureElementSymbol
constructedLayerElement
:
constructLayerElemList
)
{
if
(!
isSupportedLayer
(
constructedLayerElement
,
layerChecker
))
return
false
;
if
(!
isSupportedLayer
(
constructedLayerElement
,
layerChecker
))
{
return
false
;
}
}
}
if
(!
layerChecker
.
isSupported
(
element
.
toString
()))
{
Log
.
error
(
"Unsupported layer "
+
"'"
+
element
.
getName
()
+
"'"
+
" for the backend CAFFE2."
);
return
false
;
}
else
{
}
else
{
return
true
;
}
}
...
...
@@ -70,11 +70,18 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
private
boolean
supportCheck
(
ArchitectureSymbol
architecture
){
LayerSupportChecker
layerChecker
=
new
LayerSupportChecker
();
for
(
ArchitectureElementSymbol
element
:
((
CompositeElementSymbol
)
architecture
.
getBody
()).
getElements
()){
if
(!
isSupportedLayer
(
element
,
layerChecker
))
return
false
;
if
(!
isSupportedLayer
(
element
,
layerChecker
))
{
return
false
;
}
}
return
true
;
}
private
static
void
quitGeneration
(){
Log
.
error
(
"Code generation is aborted"
);
System
.
exit
(
1
);
}
public
CNNArch2Caffe2
()
{
setGenerationTargetPath
(
"./target/generated-sources-cnnarch/"
);
}
...
...
@@ -105,19 +112,17 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
Optional
<
CNNArchCompilationUnitSymbol
>
compilationUnit
=
scope
.
resolve
(
rootModelName
,
CNNArchCompilationUnitSymbol
.
KIND
);
if
(!
compilationUnit
.
isPresent
()){
Log
.
error
(
"could not resolve architecture "
+
rootModelName
);
System
.
exit
(
1
);
quitGeneration
(
);
}
CNNArchCocos
.
checkAll
(
compilationUnit
.
get
());
if
(!
supportCheck
(
compilationUnit
.
get
().
getArchitecture
())){
Log
.
error
(
"Code generation aborted."
);
System
.
exit
(
1
);
quitGeneration
();
}
try
{
generateFiles
(
compilationUnit
.
get
().
getArchitecture
());
}
catch
(
IOException
e
){
}
catch
(
IOException
e
){
Log
.
error
(
e
.
toString
());
}
}
...
...
@@ -172,7 +177,7 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
try
{
generateFromFilecontentsMap
(
fileContentMap
);
}
catch
(
IOException
e
)
{
e
.
printStackTrac
e
();
Log
.
error
(
"CMake file could not be generated"
+
e
.
getMessag
e
()
)
;
}
}
...
...
@@ -195,8 +200,8 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
cMakeConfig
.
addCMakeCommand
(
"set(LIBS ${LIBS} -lprotobuf -lglog -lgflags)"
);
cMakeConfig
.
addCMakeCommand
(
"find_package(CUDA)"
+
"\n"
+
"set(INCLUDE_DIRS ${INCLUDE_DIRS} ${CUDA_INCLUDE_DIRS})"
+
"\n"
+
"set(LIBS ${LIBS} ${CUDA_LIBRARIES} ${CUDA_curand_LIBRARY})"
+
"\n"
);
//Needed since CUDA cannot be found correctly (including CUDA_curand_LIBRARY) and as optional using CMakeFindModule
+
"set(LIBS ${LIBS} ${CUDA_LIBRARIES} ${CUDA_curand_LIBRARY})"
+
"\n"
);
//Needed since CUDA cannot be found correctly (including CUDA_curand_LIBRARY)
cMakeConfig
.
addCMakeCommand
(
"if(CUDA_FOUND)"
+
"\n"
+
" set(LIBS ${LIBS} caffe2 caffe2_gpu)"
+
"\n"
+
"else()"
+
"\n"
+
" set(LIBS ${LIBS} caffe2)"
+
"\n"
+
"endif()"
);
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNArch2Caffe2Cli.java
View file @
8dbe6ce3
...
...
@@ -19,6 +19,7 @@
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch.caffe2generator
;
import
de.se_rwth.commons.logging.Log
;
import
org.apache.commons.cli.*
;
...
...
@@ -73,13 +74,18 @@ public class CNNArch2Caffe2Cli {
try
{
cliArgs
=
parser
.
parse
(
options
,
args
);
}
catch
(
ParseException
e
)
{
System
.
err
.
println
(
"argument parsing exception: "
+
e
.
getMessage
());
System
.
exit
(
1
);
Log
.
error
(
"argument parsing exception: "
+
e
.
getMessage
());
quitGeneration
(
);
return
null
;
}
return
cliArgs
;
}
private
static
void
quitGeneration
(){
Log
.
error
(
"Code generation is aborted"
);
System
.
exit
(
1
);
}
private
static
void
runGenerator
(
CommandLine
cliArgs
)
{
Path
modelsDirPath
=
Paths
.
get
(
cliArgs
.
getOptionValue
(
OPTION_MODELS_PATH
.
getOpt
()));
String
rootModelName
=
cliArgs
.
getOptionValue
(
OPTION_ROOT_MODEL
.
getOpt
());
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNArchTemplateController.java
View file @
8dbe6ce3
...
...
@@ -95,8 +95,7 @@ public class CNNArchTemplateController {
if
(
isSoftmaxOutput
(
layer
)
||
isLogisticRegressionOutput
(
layer
)){
inputNames
=
getLayerInputs
(
layer
.
getInputElement
().
get
());
}
else
{
}
else
{
for
(
ArchitectureElementSymbol
input
:
layer
.
getPrevious
())
{
if
(
input
.
getOutputTypes
().
size
()
==
1
)
{
inputNames
.
add
(
getName
(
input
));
...
...
@@ -146,12 +145,10 @@ public class CNNArchTemplateController {
if
(
ioElement
.
isAtomic
()){
if
(
ioElement
.
isInput
()){
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
"Input"
,
writer
);
}
else
{
}
else
{
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
"Output"
,
writer
);
}
}
else
{
}
else
{
include
(
ioElement
.
getResolvedThis
().
get
(),
writer
);
}
...
...
@@ -168,8 +165,7 @@ public class CNNArchTemplateController {
String
templateName
=
layer
.
getDeclaration
().
getName
();
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
templateName
,
writer
);
}
}
else
{
}
else
{
include
(
layer
.
getResolvedThis
().
get
(),
writer
);
}
...
...
@@ -190,11 +186,9 @@ public class CNNArchTemplateController {
public
void
include
(
ArchitectureElementSymbol
architectureElement
,
Writer
writer
){
if
(
architectureElement
instanceof
CompositeElementSymbol
){
include
((
CompositeElementSymbol
)
architectureElement
,
writer
);
}
else
if
(
architectureElement
instanceof
LayerSymbol
){
}
else
if
(
architectureElement
instanceof
LayerSymbol
){
include
((
LayerSymbol
)
architectureElement
,
writer
);
}
else
{
}
else
{
include
((
IOSymbol
)
architectureElement
,
writer
);
}
}
...
...
@@ -207,15 +201,15 @@ public class CNNArchTemplateController {
}
public
Map
.
Entry
<
String
,
String
>
process
(
String
templateNameWithoutEnding
,
Target
targetLanguage
){
StringWriter
w
riter
=
new
StringWriter
();
StringWriter
newW
riter
=
new
StringWriter
();
this
.
mainTemplateNameWithoutEnding
=
templateNameWithoutEnding
;
this
.
targetLanguage
=
targetLanguage
;
this
.
writer
=
w
riter
;
this
.
writer
=
newW
riter
;
include
(
""
,
templateNameWithoutEnding
,
w
riter
);
include
(
""
,
templateNameWithoutEnding
,
newW
riter
);
String
fileEnding
=
targetLanguage
.
toString
();
String
fileName
=
getFileNameWithoutEnding
()
+
fileEnding
;
Map
.
Entry
<
String
,
String
>
fileContent
=
new
AbstractMap
.
SimpleEntry
<>(
fileName
,
w
riter
.
toString
());
Map
.
Entry
<
String
,
String
>
fileContent
=
new
AbstractMap
.
SimpleEntry
<>(
fileName
,
newW
riter
.
toString
());
this
.
mainTemplateNameWithoutEnding
=
null
;
this
.
targetLanguage
=
null
;
...
...
@@ -271,12 +265,12 @@ public class CNNArchTemplateController {
}
private
boolean
isTOutput
(
Class
inputPredefinedLayerClass
,
ArchitectureElementSymbol
architectureElement
){
if
(
architectureElement
.
isOutput
()
){
if
(
architectureElement
.
getInputElement
().
isPresent
()
&&
architectureElement
.
getInputElement
().
get
()
instanceof
LayerSymbol
){
LayerSymbol
inputLayer
=
(
LayerSymbol
)
architectureElement
.
getInputElement
().
get
()
;
if
(
inputPredefinedLayerClass
.
isInstance
(
inputLayer
.
getDeclaration
())){
return
true
;
}
if
(
architectureElement
.
isOutput
()
&&
architectureElement
.
getInputElement
().
isPresent
()
&&
architectureElement
.
getInputElement
().
get
()
instanceof
LayerSymbol
){
LayerSymbol
inputLayer
=
(
LayerSymbol
)
architectureElement
.
getInputElement
().
get
();
if
(
inputPredefinedLayerClass
.
isInstance
(
inputLayer
.
getDeclaration
())){
return
true
;
}
}
return
false
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNTrain2Caffe2.java
View file @
8dbe6ce3
...
...
@@ -37,7 +37,9 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
it
=
configuration
.
getEntryMap
().
keySet
().
iterator
();
while
(
it
.
hasNext
())
{
String
key
=
it
.
next
().
toString
();
if
(
funcChecker
.
getUnsupportedElemList
().
contains
(
key
))
it
.
remove
();
if
(
funcChecker
.
getUnsupportedElemList
().
contains
(
key
))
{
it
.
remove
();
}
}
}
...
...
@@ -52,12 +54,19 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
Iterator
it
=
configuration
.
getOptimizer
().
getOptimizerParamMap
().
keySet
().
iterator
();
while
(
it
.
hasNext
())
{
String
key
=
it
.
next
().
toString
();
if
(
funcChecker
.
getUnsupportedElemList
().
contains
(
key
))
it
.
remove
();
if
(
funcChecker
.
getUnsupportedElemList
().
contains
(
key
))
{
it
.
remove
();
}
}
}
}
}
private
static
void
quitGeneration
(){
Log
.
error
(
"Code generation is aborted"
);
System
.
exit
(
1
);
}
public
CNNTrain2Caffe2
()
{
setGenerationTargetPath
(
"./target/generated-sources-cnnarch/"
);
}
...
...
@@ -89,7 +98,7 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
Optional
<
CNNTrainCompilationUnitSymbol
>
compilationUnit
=
scope
.
resolve
(
rootModelName
,
CNNTrainCompilationUnitSymbol
.
KIND
);
if
(!
compilationUnit
.
isPresent
())
{
Log
.
error
(
"could not resolve training configuration "
+
rootModelName
);
System
.
exit
(
1
);
quitGeneration
(
);
}
setInstanceName
(
compilationUnit
.
get
().
getFullName
());
CNNTrainCocos
.
checkAll
(
compilationUnit
.
get
());
...
...
@@ -107,7 +116,7 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
genCPP
.
generateFile
(
new
FileContent
(
fileContents
.
get
(
fileName
),
fileName
));
}
}
catch
(
IOException
e
)
{
e
.
printStackTrac
e
();
Log
.
error
(
"CNNTrainer file could not be generated"
+
e
.
getMessag
e
()
)
;
}
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/ConfigurationData.java
View file @
8dbe6ce3
...
...
@@ -96,8 +96,7 @@ public class ConfigurationData {
Class
realClass
=
entry
.
getValue
().
getValue
().
getValue
().
getClass
();
if
(
realClass
==
Boolean
.
class
)
{
valueAsString
=
(
Boolean
)
entry
.
getValue
().
getValue
().
getValue
()
?
"True"
:
"False"
;
}
else
if
(
lrPolicyClasses
.
contains
(
realClass
))
{
}
else
if
(
lrPolicyClasses
.
contains
(
realClass
))
{
valueAsString
=
"'"
+
valueAsString
+
"'"
;
}
mapToStrings
.
put
(
paramName
,
valueAsString
);
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/LayerNameCreator.java
View file @
8dbe6ce3
...
...
@@ -47,17 +47,14 @@ public class LayerNameCreator {
protected
int
name
(
ArchitectureElementSymbol
architectureElement
,
int
stage
,
List
<
Integer
>
streamIndices
){
if
(
architectureElement
instanceof
CompositeElementSymbol
){
return
nameComposite
((
CompositeElementSymbol
)
architectureElement
,
stage
,
streamIndices
);
}
else
{
}
else
{
if
(
architectureElement
.
isAtomic
()){
if
(
architectureElement
.
getMaxSerialLength
().
get
()
>
0
){
return
add
(
architectureElement
,
stage
,
streamIndices
);
}
else
{
}
else
{
return
stage
;
}
}
else
{
}
else
{
ArchitectureElementSymbol
resolvedElement
=
architectureElement
.
getResolvedThis
().
get
();
return
name
(
resolvedElement
,
stage
,
streamIndices
);
}
...
...
@@ -78,8 +75,7 @@ public class LayerNameCreator {
streamIndices
.
remove
(
lastIndex
);
return
Collections
.
max
(
endStages
)
+
1
;
}
else
{
}
else
{
int
endStage
=
stage
;
for
(
ArchitectureElementSymbol
subElement
:
compositeElement
.
getElements
()){
endStage
=
name
(
subElement
,
endStage
,
streamIndices
);
...
...
@@ -113,8 +109,7 @@ public class LayerNameCreator {
name
=
name
+
"_"
+
arrayAccess
+
"_"
;
}
return
name
;
}
else
{
}
else
{
return
createBaseName
(
architectureElement
)
+
stage
+
createStreamPostfix
(
streamIndices
)
+
"_"
;
}
}
...
...
@@ -132,11 +127,9 @@ public class LayerNameCreator {
}
else
{
return
layerDeclaration
.
getName
().
toLowerCase
();
}
}
else
if
(
architectureElement
instanceof
CompositeElementSymbol
){
}
else
if
(
architectureElement
instanceof
CompositeElementSymbol
){
return
"group"
;
}
else
{
}
else
{
return
architectureElement
.
getName
();
}
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/TemplateConfiguration.java
View file @
8dbe6ce3
...
...
@@ -43,6 +43,11 @@ public class TemplateConfiguration {
configuration
.
setTemplateExceptionHandler
(
TemplateExceptionHandler
.
RETHROW_HANDLER
);
}
private
static
void
quitGeneration
(){
Log
.
error
(
"Code generation is aborted"
);
System
.
exit
(
1
);
}
public
Configuration
getConfiguration
()
{
return
configuration
;
}
...
...
@@ -58,14 +63,12 @@ public class TemplateConfiguration {
try
{
Template
template
=
TemplateConfiguration
.
get
().
getTemplate
(
templatePath
);
template
.
process
(
ftlContext
,
writer
);
}
catch
(
IOException
e
)
{
}
catch
(
IOException
e
)
{
Log
.
error
(
"Freemarker could not find template "
+
templatePath
+
" :\n"
+
e
.
getMessage
());
System
.
exit
(
1
);
}
catch
(
TemplateException
e
){
quitGeneration
();
}
catch
(
TemplateException
e
){
Log
.
error
(
"An exception occured in template "
+
templatePath
+
" :\n"
+
e
.
getMessage
());
System
.
exit
(
1
);
quitGeneration
(
);
}
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/TrainParamSupportChecker.java
View file @
8dbe6ce3
...
...
@@ -25,12 +25,14 @@ public class TrainParamSupportChecker implements CNNTrainVisitor {
public
TrainParamSupportChecker
()
{
}
public
String
unsupportedOptFlag
=
"unsupported_optimizer"
;
public
static
final
String
unsupportedOptFlag
=
"unsupported_optimizer"
;
public
List
getUnsupportedElemList
(){
return
this
.
unsupportedElemList
;
}
//Empty visit method denotes that the corresponding training parameter is supported.
//To set a training parameter as unsupported, add the corresponding node to the unsupportedElemList
public
void
visit
(
ASTNumEpochEntry
node
){}
public
void
visit
(
ASTBatchSizeEntry
node
){}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment