Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Submit feedback
Sign in
Toggle navigation
C
CNNArch2MXNet
Project overview
Project overview
Details
Activity
Releases
Cycle Analytics
Insights
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Locked Files
Issues
1
Issues
1
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Security & Compliance
Security & Compliance
Dependency List
Packages
Packages
Container Registry
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2MXNet
Commits
5f326fc2
Commit
5f326fc2
authored
Jun 16, 2019
by
Sebastian Nickels
Browse files
Options
Browse Files
Download
Plain Diff
Merge rnn into develop
parents
20d0b24c
9f5fdbf4
Pipeline
#150600
canceled with stages
Changes
75
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
75 changed files
with
345 additions
and
708 deletions
+345
-708
pom.xml
pom.xml
+9
-3
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/ArchitectureElementData.java
...ticar/cnnarch/mxnetgenerator/ArchitectureElementData.java
+8
-0
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/ArchitectureSupportChecker.java
...ar/cnnarch/mxnetgenerator/ArchitectureSupportChecker.java
+66
-0
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArch2MxNet.java
...e/lang/monticar/cnnarch/mxnetgenerator/CNNArch2MxNet.java
+4
-8
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArch2MxNetArchitectureSupportChecker.java
...netgenerator/CNNArch2MxNetArchitectureSupportChecker.java
+7
-0
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArch2MxNetLayerSupportChecker.java
...arch/mxnetgenerator/CNNArch2MxNetLayerSupportChecker.java
+27
-0
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArch2MxNetTemplateController.java
...narch/mxnetgenerator/CNNArch2MxNetTemplateController.java
+3
-4
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArchSymbolCompiler.java
...onticar/cnnarch/mxnetgenerator/CNNArchSymbolCompiler.java
+11
-35
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArchTemplateController.java
...car/cnnarch/mxnetgenerator/CNNArchTemplateController.java
+7
-2
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/LayerNameCreator.java
...ang/monticar/cnnarch/mxnetgenerator/LayerNameCreator.java
+4
-1
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/LayerSupportChecker.java
.../monticar/cnnarch/mxnetgenerator/LayerSupportChecker.java
+64
-0
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/MxNetTemplateConfiguration.java
...ar/cnnarch/mxnetgenerator/MxNetTemplateConfiguration.java
+0
-1
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/checker/AllowAllLayerSupportChecker.java
...h/mxnetgenerator/checker/AllowAllLayerSupportChecker.java
+0
-22
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/checker/LayerSupportChecker.java
...r/cnnarch/mxnetgenerator/checker/LayerSupportChecker.java
+0
-5
src/main/resources/templates/mxnet/CNNCreator.ftl
src/main/resources/templates/mxnet/CNNCreator.ftl
+1
-1
src/main/resources/templates/mxnet/elements/OneHot.ftl
src/main/resources/templates/mxnet/elements/OneHot.ftl
+3
-0
src/main/resources/templates/mxnet/elements/Output.ftl
src/main/resources/templates/mxnet/elements/Output.ftl
+3
-0
src/test/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/GenerationTest.java
.../lang/monticar/cnnarch/mxnetgenerator/GenerationTest.java
+31
-4
src/test/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/SymtabTest.java
...core/lang/monticar/cnnarch/mxnetgenerator/SymtabTest.java
+3
-3
src/test/resources/architectures/Alexnet.cnna
src/test/resources/architectures/Alexnet.cnna
+1
-1
src/test/resources/architectures/ResNeXt50.cnna
src/test/resources/architectures/ResNeXt50.cnna
+1
-1
src/test/resources/architectures/ResNet152.cnna
src/test/resources/architectures/ResNet152.cnna
+1
-1
src/test/resources/architectures/ResNet34.cnna
src/test/resources/architectures/ResNet34.cnna
+1
-1
src/test/resources/architectures/SequentialAlexnet.cnna
src/test/resources/architectures/SequentialAlexnet.cnna
+1
-1
src/test/resources/architectures/ThreeInputCNN_M14.cnna
src/test/resources/architectures/ThreeInputCNN_M14.cnna
+1
-1
src/test/resources/architectures/VGG16.cnna
src/test/resources/architectures/VGG16.cnna
+1
-1
src/test/resources/invalid_tests/ArgumentConstraintTest1.cnna
...test/resources/invalid_tests/ArgumentConstraintTest1.cnna
+0
-34
src/test/resources/invalid_tests/ArgumentConstraintTest2.cnna
...test/resources/invalid_tests/ArgumentConstraintTest2.cnna
+0
-34
src/test/resources/invalid_tests/ArgumentConstraintTest3.cnna
...test/resources/invalid_tests/ArgumentConstraintTest3.cnna
+0
-34
src/test/resources/invalid_tests/ArgumentConstraintTest4.cnna
...test/resources/invalid_tests/ArgumentConstraintTest4.cnna
+0
-34
src/test/resources/invalid_tests/ArgumentConstraintTest5.cnna
...test/resources/invalid_tests/ArgumentConstraintTest5.cnna
+0
-34
src/test/resources/invalid_tests/ArgumentConstraintTest6.cnna
...test/resources/invalid_tests/ArgumentConstraintTest6.cnna
+0
-34
src/test/resources/invalid_tests/DuplicatedArgument.cnna
src/test/resources/invalid_tests/DuplicatedArgument.cnna
+0
-11
src/test/resources/invalid_tests/DuplicatedIONames.cnna
src/test/resources/invalid_tests/DuplicatedIONames.cnna
+0
-16
src/test/resources/invalid_tests/DuplicatedNames.cnna
src/test/resources/invalid_tests/DuplicatedNames.cnna
+0
-20
src/test/resources/invalid_tests/IllegalIOName.cnna
src/test/resources/invalid_tests/IllegalIOName.cnna
+0
-11
src/test/resources/invalid_tests/IllegalName.cnna
src/test/resources/invalid_tests/IllegalName.cnna
+0
-15
src/test/resources/invalid_tests/InvalidArrayAccessValue.cnna
...test/resources/invalid_tests/InvalidArrayAccessValue.cnna
+0
-27
src/test/resources/invalid_tests/InvalidIOShape1.cnna
src/test/resources/invalid_tests/InvalidIOShape1.cnna
+0
-11
src/test/resources/invalid_tests/InvalidIOShape2.cnna
src/test/resources/invalid_tests/InvalidIOShape2.cnna
+0
-11
src/test/resources/invalid_tests/InvalidInputShape.cnna
src/test/resources/invalid_tests/InvalidInputShape.cnna
+0
-11
src/test/resources/invalid_tests/InvalidRecursion.cnna
src/test/resources/invalid_tests/InvalidRecursion.cnna
+0
-40
src/test/resources/invalid_tests/MissingArgument.cnna
src/test/resources/invalid_tests/MissingArgument.cnna
+0
-34
src/test/resources/invalid_tests/MissingIO2.cnna
src/test/resources/invalid_tests/MissingIO2.cnna
+0
-11
src/test/resources/invalid_tests/MissingLayerOperator.cnna
src/test/resources/invalid_tests/MissingLayerOperator.cnna
+0
-11
src/test/resources/invalid_tests/MissingMerge.cnna
src/test/resources/invalid_tests/MissingMerge.cnna
+0
-24
src/test/resources/invalid_tests/MissingParallelBrackets.cnna
...test/resources/invalid_tests/MissingParallelBrackets.cnna
+0
-45
src/test/resources/invalid_tests/MultipleOutputs.cnna
src/test/resources/invalid_tests/MultipleOutputs.cnna
+1
-1
src/test/resources/invalid_tests/MultipleStreams.cnna
src/test/resources/invalid_tests/MultipleStreams.cnna
+14
-0
src/test/resources/invalid_tests/NotIOArray.cnna
src/test/resources/invalid_tests/NotIOArray.cnna
+0
-11
src/test/resources/invalid_tests/UnfinishedArchitecture.cnna
src/test/resources/invalid_tests/UnfinishedArchitecture.cnna
+0
-15
src/test/resources/invalid_tests/UnknownIO.cnna
src/test/resources/invalid_tests/UnknownIO.cnna
+0
-8
src/test/resources/invalid_tests/UnknownMethod.cnna
src/test/resources/invalid_tests/UnknownMethod.cnna
+0
-11
src/test/resources/invalid_tests/UnknownVariableName.cnna
src/test/resources/invalid_tests/UnknownVariableName.cnna
+0
-11
src/test/resources/invalid_tests/WrongArgument.cnna
src/test/resources/invalid_tests/WrongArgument.cnna
+0
-11
src/test/resources/invalid_tests/WrongIOType.cnna
src/test/resources/invalid_tests/WrongIOType.cnna
+0
-11
src/test/resources/invalid_tests/WrongRangeOperator.cnna
src/test/resources/invalid_tests/WrongRangeOperator.cnna
+0
-11
src/test/resources/invalid_tests/data_paths.txt
src/test/resources/invalid_tests/data_paths.txt
+2
-3
src/test/resources/valid_tests/Alexnet_alt.cnna
src/test/resources/valid_tests/Alexnet_alt.cnna
+1
-1
src/test/resources/valid_tests/Alexnet_alt2.cnna
src/test/resources/valid_tests/Alexnet_alt2.cnna
+1
-1
src/test/resources/valid_tests/Alexnet_alt_OneHotOutput.cnna
src/test/resources/valid_tests/Alexnet_alt_OneHotOutput.cnna
+54
-0
src/test/resources/valid_tests/ArgumentSequenceTest.cnna
src/test/resources/valid_tests/ArgumentSequenceTest.cnna
+1
-1
src/test/resources/valid_tests/CifarClassifierNetwork.cnna
src/test/resources/valid_tests/CifarClassifierNetwork.cnna
+1
-1
src/test/resources/valid_tests/Fixed_Alexnet.cnna
src/test/resources/valid_tests/Fixed_Alexnet.cnna
+1
-1
src/test/resources/valid_tests/Fixed_ThreeInputCNN_M14.cnna
src/test/resources/valid_tests/Fixed_ThreeInputCNN_M14.cnna
+1
-1
src/test/resources/valid_tests/ResNeXt50_InstanceTest.cnna
src/test/resources/valid_tests/ResNeXt50_InstanceTest.cnna
+1
-1
src/test/resources/valid_tests/ResNeXt50_alt.cnna
src/test/resources/valid_tests/ResNeXt50_alt.cnna
+1
-1
src/test/resources/valid_tests/ResNet152_alt.cnna
src/test/resources/valid_tests/ResNet152_alt.cnna
+1
-1
src/test/resources/valid_tests/SimpleNetworkLinear.cnna
src/test/resources/valid_tests/SimpleNetworkLinear.cnna
+1
-1
src/test/resources/valid_tests/SimpleNetworkRelu.cnna
src/test/resources/valid_tests/SimpleNetworkRelu.cnna
+1
-1
src/test/resources/valid_tests/SimpleNetworkSigmoid.cnna
src/test/resources/valid_tests/SimpleNetworkSigmoid.cnna
+1
-1
src/test/resources/valid_tests/SimpleNetworkSoftmax.cnna
src/test/resources/valid_tests/SimpleNetworkSoftmax.cnna
+1
-1
src/test/resources/valid_tests/SimpleNetworkTanh.cnna
src/test/resources/valid_tests/SimpleNetworkTanh.cnna
+1
-1
src/test/resources/valid_tests/ThreeInputCNN_M14_alternative.cnna
.../resources/valid_tests/ThreeInputCNN_M14_alternative.cnna
+1
-1
src/test/resources/valid_tests/data_paths.txt
src/test/resources/valid_tests/data_paths.txt
+1
-2
No files found.
pom.xml
View file @
5f326fc2
...
...
@@ -8,15 +8,15 @@
<groupId>
de.monticore.lang.monticar
</groupId>
<artifactId>
cnnarch-mxnet-generator
</artifactId>
<version>
0.2.1
5
-SNAPSHOT
</version>
<version>
0.2.1
6
-SNAPSHOT
</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>
0.3.
0
-SNAPSHOT
</CNNArch.version>
<CNNTrain.version>
0.
2.6
</CNNTrain.version>
<CNNArch.version>
0.3.
1
-SNAPSHOT
</CNNArch.version>
<CNNTrain.version>
0.
3.2-SNAPSHOT
</CNNTrain.version>
<embedded-montiarc-math-opt-generator>
0.1.4
</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. -->
...
...
@@ -102,6 +102,12 @@
<scope>
test
</scope>
</dependency>
<dependency>
<groupId>
com.github.stefanbirkner
</groupId>
<artifactId>
system-rules
</artifactId>
<version>
1.3.0
</version>
</dependency>
<dependency>
<groupId>
ch.qos.logback
</groupId>
<artifactId>
logback-classic
</artifactId>
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/ArchitectureElementData.java
View file @
5f326fc2
...
...
@@ -90,6 +90,9 @@ public class ArchitectureElementData {
return
getTemplateController
().
isSoftmaxOutput
(
getElement
());
}
public
boolean
isOneHotOutput
(){
return
getTemplateController
().
isOneHotOutput
(
getElement
());
}
...
...
@@ -158,6 +161,11 @@ public class ArchitectureElementData {
.
getDoubleValue
(
AllPredefinedLayers
.
BETA_NAME
).
get
();
}
public
int
getSize
(){
return
((
LayerSymbol
)
getElement
())
.
getIntValue
(
AllPredefinedLayers
.
SIZE_NAME
).
get
();
}
@Nullable
public
String
getPoolType
(){
return
((
LayerSymbol
)
getElement
())
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/ArchitectureSupportChecker.java
0 → 100644
View file @
5f326fc2
package
de
.
monticore
.
lang
.
monticar
.
cnnarch
.
mxnetgenerator
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol
;
import
de.se_rwth.commons.logging.Log
;
public
class
ArchitectureSupportChecker
{
public
ArchitectureSupportChecker
()
{}
// Overload functions returning always true to enable the features
protected
boolean
checkMultipleStreams
(
ArchitectureSymbol
architecture
)
{
if
(
architecture
.
getStreams
().
size
()
!=
1
)
{
Log
.
error
(
"This cnn architecture has multiple instructions, "
+
"which is currently not supported by the code generator. "
,
architecture
.
getSourcePosition
());
return
false
;
}
return
true
;
}
protected
boolean
checkMultipleInputs
(
ArchitectureSymbol
architecture
)
{
if
(
architecture
.
getInputs
().
size
()
>
1
)
{
Log
.
error
(
"This cnn architecture has multiple inputs, "
+
"which is currently not supported by the code generator. "
,
architecture
.
getSourcePosition
());
return
false
;
}
return
true
;
}
protected
boolean
checkMultipleOutputs
(
ArchitectureSymbol
architecture
)
{
if
(
architecture
.
getOutputs
().
size
()
>
1
)
{
Log
.
error
(
"This cnn architecture has multiple outputs, "
+
"which is currently not supported by the code generator. "
,
architecture
.
getSourcePosition
());
return
false
;
}
return
true
;
}
protected
boolean
checkMultiDimensionalOutput
(
ArchitectureSymbol
architecture
)
{
if
(
architecture
.
getOutputs
().
get
(
0
).
getDefinition
().
getType
().
getWidth
()
!=
1
||
architecture
.
getOutputs
().
get
(
0
).
getDefinition
().
getType
().
getHeight
()
!=
1
)
{
Log
.
error
(
"This cnn architecture has a multi-dimensional output, "
+
"which is currently not supported by the code generator."
,
architecture
.
getSourcePosition
());
return
false
;
}
return
true
;
}
public
boolean
check
(
ArchitectureSymbol
architecture
)
{
return
checkMultipleStreams
(
architecture
)
&&
checkMultipleInputs
(
architecture
)
&&
checkMultipleOutputs
(
architecture
)
&&
checkMultiDimensionalOutput
(
architecture
);
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArch2MxNet.java
View file @
5f326fc2
...
...
@@ -23,7 +23,6 @@ package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import
de.monticore.lang.monticar.cnnarch.CNNArchGenerator
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol
;
import
de.monticore.lang.monticar.cnnarch.DataPathConfigParser
;
import
de.monticore.lang.monticar.cnnarch.mxnetgenerator.checker.AllowAllLayerSupportChecker
;
import
de.monticore.lang.monticar.generator.FileContent
;
import
de.monticore.lang.monticar.generator.cmake.CMakeConfig
;
import
de.monticore.lang.monticar.generator.cmake.CMakeFindModule
;
...
...
@@ -36,12 +35,14 @@ import java.util.HashMap;
import
java.util.Map
;
public
class
CNNArch2MxNet
extends
CNNArchGenerator
{
public
CNNArch2MxNet
()
{
setGenerationTargetPath
(
"./target/generated-sources-cnnarch/"
);
}
public
void
generate
(
Scope
scope
,
String
rootModelName
){
CNNArchSymbolCompiler
symbolCompiler
=
new
CNNArchSymbolCompiler
(
new
AllowAllLayerSupportChecker
());
CNNArchSymbolCompiler
symbolCompiler
=
new
CNNArchSymbolCompiler
(
new
CNNArch2MxNetArchitectureSupportChecker
(),
new
CNNArch2MxNetLayerSupportChecker
());
ArchitectureSymbol
architectureSymbol
=
symbolCompiler
.
compileArchitectureSymbol
(
scope
,
rootModelName
);
try
{
...
...
@@ -58,11 +59,8 @@ public class CNNArch2MxNet extends CNNArchGenerator {
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public
Map
<
String
,
String
>
generateStrings
(
ArchitectureSymbol
architecture
){
TemplateConfiguration
templateConfiguration
=
new
MxNetTemplateConfiguration
();
Map
<
String
,
String
>
fileContentMap
=
new
HashMap
<>();
CNNArch2MxNetTemplateController
archTc
=
new
CNNArch2MxNetTemplateController
(
architecture
,
templateConfiguration
);
CNNArch2MxNetTemplateController
archTc
=
new
CNNArch2MxNetTemplateController
(
architecture
);
Map
.
Entry
<
String
,
String
>
temp
;
temp
=
archTc
.
process
(
"CNNPredictor"
,
Target
.
CPP
);
...
...
@@ -77,8 +75,6 @@ public class CNNArch2MxNet extends CNNArchGenerator {
temp
=
archTc
.
process
(
"CNNBufferFile"
,
Target
.
CPP
);
fileContentMap
.
put
(
"CNNBufferFile.h"
,
temp
.
getValue
());
checkValidGeneration
(
architecture
);
return
fileContentMap
;
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArch2MxNetArchitectureSupportChecker.java
0 → 100644
View file @
5f326fc2
package
de
.
monticore
.
lang
.
monticar
.
cnnarch
.
mxnetgenerator
;
public
class
CNNArch2MxNetArchitectureSupportChecker
extends
ArchitectureSupportChecker
{
public
CNNArch2MxNetArchitectureSupportChecker
()
{}
}
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArch2MxNetLayerSupportChecker.java
0 → 100644
View file @
5f326fc2
package
de
.
monticore
.
lang
.
monticar
.
cnnarch
.
mxnetgenerator
;
import
de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers
;
public
class
CNNArch2MxNetLayerSupportChecker
extends
LayerSupportChecker
{
public
CNNArch2MxNetLayerSupportChecker
()
{
supportedLayerList
.
add
(
AllPredefinedLayers
.
FULLY_CONNECTED_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
CONVOLUTION_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
SOFTMAX_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
SIGMOID_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
TANH_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
RELU_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
DROPOUT_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
POOLING_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
GLOBAL_POOLING_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
LRN_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
BATCHNORM_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
SPLIT_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
GET_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
ADD_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
CONCATENATE_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
FLATTEN_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
ONE_HOT_NAME
);
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArch2MxNetTemplateController.java
View file @
5f326fc2
...
...
@@ -9,9 +9,8 @@ import java.io.Writer;
*/
public
class
CNNArch2MxNetTemplateController
extends
CNNArchTemplateController
{
public
CNNArch2MxNetTemplateController
(
ArchitectureSymbol
architecture
,
TemplateConfiguration
templateConfiguration
)
{
super
(
architecture
,
templateConfiguration
);
public
CNNArch2MxNetTemplateController
(
ArchitectureSymbol
architecture
)
{
super
(
architecture
,
new
MxNetTemplateConfiguration
());
}
public
void
include
(
IOSymbol
ioElement
,
Writer
writer
){
...
...
@@ -37,7 +36,7 @@ public class CNNArch2MxNetTemplateController extends CNNArchTemplateController {
if
(
layer
.
isAtomic
()){
ArchitectureElementSymbol
nextElement
=
layer
.
getOutputElement
().
get
();
if
(!
isSoftmaxOutput
(
nextElement
)
&&
!
isLogisticRegressionOutput
(
nextElement
)){
if
(!
isSoftmaxOutput
(
nextElement
)
&&
!
isLogisticRegressionOutput
(
nextElement
)
&&
!
isOneHotOutput
(
nextElement
)
){
String
templateName
=
layer
.
getDeclaration
().
getName
();
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
templateName
,
writer
);
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArchSymbolCompiler.java
View file @
5f326fc2
...
...
@@ -3,7 +3,6 @@ package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import
de.monticore.io.paths.ModelPath
;
import
de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos
;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
de.monticore.lang.monticar.cnnarch.mxnetgenerator.checker.LayerSupportChecker
;
import
de.monticore.symboltable.GlobalScope
;
import
de.monticore.symboltable.Scope
;
import
de.se_rwth.commons.logging.Log
;
...
...
@@ -13,10 +12,13 @@ import java.util.List;
import
java.util.Optional
;
public
class
CNNArchSymbolCompiler
{
private
final
LayerSupportChecker
layerChecker
;
private
final
ArchitectureSupportChecker
architectureSupportChecker
;
private
final
LayerSupportChecker
layerSupportChecker
;
public
CNNArchSymbolCompiler
(
final
LayerSupportChecker
layerChecker
)
{
this
.
layerChecker
=
layerChecker
;
public
CNNArchSymbolCompiler
(
final
ArchitectureSupportChecker
architectureSupportChecker
,
final
LayerSupportChecker
layerSupportChecker
)
{
this
.
architectureSupportChecker
=
architectureSupportChecker
;
this
.
layerSupportChecker
=
layerSupportChecker
;
}
public
ArchitectureSymbol
compileArchitectureSymbolFromModelsDir
(
...
...
@@ -30,47 +32,21 @@ public class CNNArchSymbolCompiler {
Optional
<
CNNArchCompilationUnitSymbol
>
compilationUnit
=
scope
.
resolve
(
rootModelName
,
CNNArchCompilationUnitSymbol
.
KIND
);
if
(!
compilationUnit
.
isPresent
()){
failWithMessage
(
"Could not resolve architecture "
+
rootModelName
);
}
CNNArchCocos
.
checkAll
(
compilationUnit
.
get
());
if
(!
supportCheck
(
compilationUnit
.
get
().
getArchitecture
())){
ArchitectureSymbol
architecture
=
compilationUnit
.
get
().
getArchitecture
();
if
(!
architectureSupportChecker
.
check
(
architecture
)
||
!
layerSupportChecker
.
check
(
architecture
))
{
failWithMessage
(
"Architecture not supported by generator"
);
}
return
compilationUnit
.
get
().
getArchitecture
()
;
return
architecture
;
}
private
void
failWithMessage
(
final
String
message
)
{
Log
.
error
(
message
);
System
.
exit
(
1
);
}
private
boolean
supportCheck
(
ArchitectureSymbol
architecture
){
for
(
ArchitectureElementSymbol
element
:
((
CompositeElementSymbol
)
architecture
.
getBody
()).
getElements
()){
if
(!
isSupportedLayer
(
element
,
layerChecker
))
{
return
false
;
}
}
return
true
;
}
private
boolean
isSupportedLayer
(
ArchitectureElementSymbol
element
,
LayerSupportChecker
layerChecker
){
List
<
ArchitectureElementSymbol
>
constructLayerElemList
;
if
(
element
.
getResolvedThis
().
get
()
instanceof
CompositeElementSymbol
)
{
constructLayerElemList
=
((
CompositeElementSymbol
)
element
.
getResolvedThis
().
get
()).
getElements
();
for
(
ArchitectureElementSymbol
constructedLayerElement
:
constructLayerElemList
)
{
if
(!
isSupportedLayer
(
constructedLayerElement
,
layerChecker
))
{
return
false
;
}
}
}
if
(!
layerChecker
.
isSupported
(
element
.
toString
()))
{
Log
.
error
(
"Unsupported layer "
+
"'"
+
element
.
getName
()
+
"'"
+
" for the backend."
);
return
false
;
}
else
{
return
true
;
}
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArchTemplateController.java
View file @
5f326fc2
...
...
@@ -23,6 +23,7 @@ package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
de.monticore.lang.monticar.cnnarch.predefined.Sigmoid
;
import
de.monticore.lang.monticar.cnnarch.predefined.Softmax
;
import
de.monticore.lang.monticar.cnnarch.predefined.OneHot
;
import
java.io.StringWriter
;
import
java.io.Writer
;
...
...
@@ -139,7 +140,7 @@ public abstract class CNNArchTemplateController {
public
List
<
String
>
getLayerInputs
(
ArchitectureElementSymbol
layer
){
List
<
String
>
inputNames
=
new
ArrayList
<>();
if
(
isSoftmaxOutput
(
layer
)
||
isLogisticRegressionOutput
(
layer
)){
if
(
isSoftmaxOutput
(
layer
)
||
isLogisticRegressionOutput
(
layer
)
||
isOneHotOutput
(
layer
)
){
inputNames
=
getLayerInputs
(
layer
.
getInputElement
().
get
());
}
else
{
for
(
ArchitectureElementSymbol
input
:
layer
.
getPrevious
())
{
...
...
@@ -228,9 +229,13 @@ public abstract class CNNArchTemplateController {
public
boolean
isLinearRegressionOutput
(
ArchitectureElementSymbol
architectureElement
){
return
architectureElement
.
isOutput
()
&&
!
isLogisticRegressionOutput
(
architectureElement
)
&&
!
isSoftmaxOutput
(
architectureElement
);
&&
!
isSoftmaxOutput
(
architectureElement
)
&&
!
isOneHotOutput
(
architectureElement
);
}
public
boolean
isOneHotOutput
(
ArchitectureElementSymbol
architectureElement
){
return
isTOutput
(
OneHot
.
class
,
architectureElement
);
}
public
boolean
isSoftmaxOutput
(
ArchitectureElementSymbol
architectureElement
){
return
isTOutput
(
Softmax
.
class
,
architectureElement
);
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/LayerNameCreator.java
View file @
5f326fc2
...
...
@@ -33,7 +33,10 @@ public class LayerNameCreator {
private
Map
<
String
,
ArchitectureElementSymbol
>
nameToElement
=
new
HashMap
<>();
public
LayerNameCreator
(
ArchitectureSymbol
architecture
)
{
name
(
architecture
.
getBody
(),
1
,
new
ArrayList
<>());
int
stage
=
1
;
for
(
CompositeElementSymbol
stream
:
architecture
.
getStreams
())
{
stage
=
name
(
stream
,
stage
,
new
ArrayList
<>());
}
}
public
ArchitectureElementSymbol
getArchitectureElement
(
String
name
){
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/LayerSupportChecker.java
0 → 100644
View file @
5f326fc2
package
de
.
monticore
.
lang
.
monticar
.
cnnarch
.
mxnetgenerator
;
import
de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.ArrayList
;
import
java.util.List
;
public
abstract
class
LayerSupportChecker
{
protected
List
<
String
>
supportedLayerList
=
new
ArrayList
<>();
private
boolean
isSupportedLayer
(
ArchitectureElementSymbol
element
){
ArchitectureElementSymbol
resolvedElement
=
element
.
getResolvedThis
().
get
();
List
<
ArchitectureElementSymbol
>
constructLayerElemList
;
if
(
resolvedElement
instanceof
CompositeElementSymbol
)
{
constructLayerElemList
=
((
CompositeElementSymbol
)
resolvedElement
).
getElements
();
for
(
ArchitectureElementSymbol
constructedLayerElement
:
constructLayerElemList
)
{
if
(!
isSupportedLayer
(
constructedLayerElement
))
{
return
false
;
}
}
return
true
;
}
// Support all inputs and outputs
if
(
resolvedElement
.
isInput
()
||
resolvedElement
.
isOutput
())
{
return
true
;
}
// Support all layer declarations
if
(
resolvedElement
instanceof
LayerSymbol
)
{
if
(!((
LayerSymbol
)
resolvedElement
).
getDeclaration
().
isPredefined
())
{
return
true
;
}
}
if
(!
supportedLayerList
.
contains
(
element
.
toString
()))
{
Log
.
error
(
"Unsupported layer "
+
"'"
+
element
.
getName
()
+
"'"
+
" for the current backend."
);
return
false
;
}
else
{
return
true
;
}
}
public
boolean
check
(
ArchitectureSymbol
architecture
)
{
for
(
CompositeElementSymbol
stream
:
architecture
.
getStreams
())
{
for
(
ArchitectureElementSymbol
element
:
stream
.
getElements
())
{
if
(!
isSupportedLayer
(
element
))
{
return
false
;
}
}
}
return
true
;
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/MxNetTemplateConfiguration.java
View file @
5f326fc2
...
...
@@ -9,7 +9,6 @@ public class MxNetTemplateConfiguration extends TemplateConfiguration {
private
static
Configuration
configuration
;
public
MxNetTemplateConfiguration
()
{
super
();
if
(
configuration
==
null
)
{
configuration
=
super
.
createConfiguration
();
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/checker/AllowAllLayerSupportChecker.java
deleted
100644 → 0
View file @
20d0b24c
package
de
.
monticore
.
lang
.
monticar
.
cnnarch
.
mxnetgenerator
.
checker
;
import
de.monticore.lang.monticar.cnnarch.mxnetgenerator.checker.LayerSupportChecker
;
import
java.util.ArrayList
;
import
java.util.List
;
public
class
AllowAllLayerSupportChecker
implements
LayerSupportChecker
{
private
List
<
String
>
unsupportedLayerList
=
new
ArrayList
<>();
public
AllowAllLayerSupportChecker
()
{
//Set the unsupported layers for the backend
//this.unsupportedLayerList.add(PREDEFINED_LAYER_NAME);
}
@Override
public
boolean
isSupported
(
String
element
)
{
return
!
this
.
unsupportedLayerList
.
contains
(
element
);
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/checker/LayerSupportChecker.java
deleted
100644 → 0
View file @
20d0b24c
package
de
.
monticore
.
lang
.
monticar
.
cnnarch
.
mxnetgenerator
.
checker
;
public
interface
LayerSupportChecker
{
boolean
isSupported
(
String
element
);
}
\ No newline at end of file
src/main/resources/templates/mxnet/CNNCreator.ftl
View file @
5f326fc2
...
...
@@ -172,7 +172,7 @@ class ${tc.fileNameWithoutEnding}:
def construct(self, context, data_mean=None, data_std=None):
${tc.include(tc.architecture.
body
)}
${tc.include(tc.architecture.
streams[0]
)}
self.module = mx.mod.Module(symbol=mx.symbol.Group([${tc.join(tc.architectureOutputs, ",")}]),
data_names=self._input_names_,
label_names=self._output_names_,
...
...
src/main/resources/templates/mxnet/elements/OneHot.ftl
0 → 100644
View file @
5f326fc2
${element.name} = mx.symbol.one_hot(data=${element.inputs[0]},
indices=mx.symbol.argmax(data=${element.inputs[0]}, axis=1), depth=${element.size}))
<#include "OutputShape.ftl">
\ No newline at end of file
src/main/resources/templates/mxnet/elements/Output.ftl
View file @
5f326fc2
...
...
@@ -8,4 +8,7 @@
<#elseif element.linearRegressionOutput>
${element.name} = mx.symbol.LinearRegressionOutput(data=${element.inputs[0]},
name="${element.name}")
<#elseif element.oneHotOutput>
${element.name} = mx.symbol.SoftmaxOutput(data=${element.inputs[0]},
name="${element.name}")
</#if>
\ No newline at end of file
src/test/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/GenerationTest.java
View file @
5f326fc2
...
...
@@ -23,6 +23,7 @@ package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import
de.se_rwth.commons.logging.Log
;
import
freemarker.template.TemplateException
;
import
org.junit.Before
;
import
org.junit.Rule
;
import
org.junit.Test
;
import
java.io.IOException
;
...
...
@@ -30,9 +31,13 @@ import java.nio.file.Path;
import
java.nio.file.Paths
;
import
java.util.*
;
import
org.junit.contrib.java.lang.system.Assertion
;
import
org.junit.contrib.java.lang.system.ExpectedSystemExit
;
import
static
junit
.
framework
.
TestCase
.
assertTrue
;
public
class
GenerationTest
extends
AbstractSymtabTest
{
@Rule
public
final
ExpectedSystemExit
exit
=
ExpectedSystemExit
.
none
();
@Before
public
void
setUp
()
{
...
...
@@ -90,13 +95,17 @@ public class GenerationTest extends AbstractSymtabTest{
"execute_VGG16"
));
}
@Test
public
void
testThreeInputCNNGeneration
()
throws
IOException
,
TemplateException
{
Log
.
getFindings
().
clear
();
String
[]
args
=
{
"-m"
,
"src/test/resources/architectures"
,
"-r"
,
"ThreeInputCNN_M14"
};
exit
.
expectSystemExit
();
exit
.
checkAssertionAfterwards
(
new
Assertion
()
{
public
void
checkAssertion
()
{
assertTrue
(
Log
.
getFindings
().
size
()
==
2
);
}
});
CNNArch2MxNetCli
.
main
(
args
);
assertTrue
(
Log
.
getFindings
().
size
()
==
1
);
}
@Test
...
...
@@ -107,12 +116,30 @@ public class GenerationTest extends AbstractSymtabTest{
assertTrue
(
Log
.
getFindings
().
isEmpty
());
}
@Test
public
void
testMultipleStreams
()
throws
IOException
,
TemplateException
{
Log
.
getFindings
().
clear
();
String
[]
args
=
{
"-m"
,
"src/test/resources/invalid_tests"
,
"-r"
,
"MultipleStreams"
};
exit
.
expectSystemExit
();
exit
.
checkAssertionAfterwards
(
new
Assertion
()
{
public
void
checkAssertion
()
{
assertTrue
(
Log
.
getFindings
().
size
()
==
2
);
}
});
CNNArch2MxNetCli
.
main
(
args
);
}
@Test
public
void
testMultipleOutputs
()
throws
IOException
,
TemplateException
{
Log
.
getFindings
().
clear
();
String
[]
args
=
{
"-m"
,
"src/test/resources/valid_tests"
,
"-r"
,
"MultipleOutputs"
};
String
[]
args
=
{
"-m"
,
"src/test/resources/invalid_tests"
,
"-r"
,
"MultipleOutputs"
};
exit
.
expectSystemExit
();
exit
.
checkAssertionAfterwards
(
new
Assertion
()
{
public
void
checkAssertion
()
{
assertTrue
(
Log
.
getFindings
().
size
()
==
2
);
}
});
CNNArch2MxNetCli
.
main
(
args
);
assertTrue
(
Log
.
getFindings
().
size
()
==
3
);
}
@Test
...
...
src/test/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/SymtabTest.java
View file @
5f326fc2
...
...
@@ -55,7 +55,7 @@ public class SymtabTest extends AbstractSymtabTest {
CNNArchCompilationUnitSymbol
.
KIND
).
orElse
(
null
);
assertNotNull
(
a
);
a
.
resolve
();
a
.
getArchitecture
().
get
Body
(
).
getOutputTypes
();
a
.
getArchitecture
().
get
Streams
().
get
(
0
).
getOutputTypes
();
}
@Ignore
...
...
@@ -67,7 +67,7 @@ public class SymtabTest extends AbstractSymtabTest {
CNNArchCompilationUnitSymbol
.
KIND
).
orElse
(
null
);
assertNotNull
(
a
);
a
.
resolve
();
a
.
getArchitecture
().
get
Body
(
).
getOutputTypes
();
a
.
getArchitecture
().
get
Streams
().
get
(
0
).
getOutputTypes
();
}
@Ignore
...
...
@@ -79,7 +79,7 @@ public class SymtabTest extends AbstractSymtabTest {
CNNArchCompilationUnitSymbol
.
KIND
).
orElse
(
null
);
assertNotNull
(
a
);
a
.
resolve
();
a
.
getArchitecture
().
get
Body
(
).
getOutputTypes
();
a
.
getArchitecture
().
get
Streams
().
get
(
0
).
getOutputTypes
();
}
}
src/test/resources/architectures/Alexnet.cnna
View file @
5f326fc2
...
...
@@ -39,5 +39,5 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
fc(->=2) ->
FullyConnected(units=10) ->
Softmax() ->
predictions
predictions
;
}
\ No newline at end of file
src/test/resources/architectures/ResNeXt50.cnna
View file @
5f326fc2
...
...
@@ -40,5 +40,5 @@ architecture ResNeXt50(img_height=224, img_width=224, img_channels=3, classes=10
GlobalPooling(pool_type="avg") ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions
;
}
\ No newline at end of file
src/test/resources/architectures/ResNet152.cnna
View file @
5f326fc2
...
...
@@ -33,5 +33,5 @@ architecture ResNet152(img_height=224, img_width=224, img_channels=3, classes=10
GlobalPooling(pool_type="avg") ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions
;
}
\ No newline at end of file
src/test/resources/architectures/ResNet34.cnna
View file @
5f326fc2
...
...
@@ -31,5 +31,5 @@ architecture ResNet34(img_height=224, img_width=224, img_channels=3, classes=100
GlobalPooling(pool_type="avg") ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions
;
}
src/test/resources/architectures/SequentialAlexnet.cnna
View file @
5f326fc2
...
...
@@ -25,5 +25,5 @@ architecture SequentialAlexnet(img_height=224, img_width=224, img_channels=3, cl
fc() ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions
;
}
src/test/resources/architectures/ThreeInputCNN_M14.cnna
View file @
5f326fc2
...
...
@@ -28,5 +28,5 @@ architecture ThreeInputCNN_M14(img_height=200, img_width=300, img_channels=3, cl
Relu() ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions
;
}
\ No newline at end of file
src/test/resources/architectures/VGG16.cnna
View file @
5f326fc2
...
...
@@ -27,5 +27,5 @@ architecture VGG16(img_height=224, img_width=224, img_channels=3, classes=1000){
fc() ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions
;
}
\ No newline at end of file
src/test/resources/invalid_tests/ArgumentConstraintTest1.cnna
deleted
100644 → 0
View file @
20d0b24c
architecture ArgumentConstraintTest1(img_height=224, img_width=224, img_channels=3, classes=1000){
def input Z(0:255)^{img_channels, img_height, img_width} image
def output Q(0:1)^{classes} predictions
def conv(kernel, channels, stride=1, act=true){
Convolution(kernel=(kernel,kernel), channels=channels, stride=(stride,stride)) ->
BatchNorm() ->
Relu(?=act)
}
def skip(channels, stride){
Convolution(kernel=(1,1), channels=75, stride=(stride,stride)) ->
BatchNorm()
}
def resLayer(channels, stride=1){
(
conv(kernel=3, channels=channels, stride=stride) ->