Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
What's new
7
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2X
Commits
b14b8540
Commit
b14b8540
authored
Dec 04, 2019
by
Evgeny Kusmenko
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'develop' into 'master'
Added Unroll-related work and support for new layers See merge request
!7
parents
bbe7e4a4
2ed7b07c
Pipeline
#214600
passed with stages
in 4 minutes and 53 seconds
Changes
7
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
156 additions
and
71 deletions
+156
-71
pom.xml
pom.xml
+4
-3
src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureElementData.java
...g/monticar/cnnarch/generator/ArchitectureElementData.java
+16
-10
src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureSupportChecker.java
...onticar/cnnarch/generator/ArchitectureSupportChecker.java
+21
-7
src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchTemplateController.java
...monticar/cnnarch/generator/CNNArchTemplateController.java
+34
-1
src/main/java/de/monticore/lang/monticar/cnnarch/generator/ConfigurationData.java
...re/lang/monticar/cnnarch/generator/ConfigurationData.java
+46
-5
src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerNameCreator.java
...ore/lang/monticar/cnnarch/generator/LayerNameCreator.java
+30
-38
src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerSupportChecker.java
.../lang/monticar/cnnarch/generator/LayerSupportChecker.java
+5
-7
No files found.
pom.xml
View file @
b14b8540
...
...
@@ -9,15 +9,16 @@
<groupId>
de.monticore.lang.monticar
</groupId>
<artifactId>
cnnarch-generator
</artifactId>
<version>
0.0.
4
-SNAPSHOT
</version>
<version>
0.0.
5
-SNAPSHOT
</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>
0.3.3-SNAPSHOT
</CNNArch.version>
<CNNTrain.version>
0.3.6-SNAPSHOT
</CNNTrain.version>
<CNNArch.version>
0.3.4-SNAPSHOT
</CNNArch.version>
<CNNTrain.version>
0.3.9-SNAPSHOT
</CNNTrain.version>
<embedded-montiarc-math-opt-generator>
0.1.4
</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. -->
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureElementData.java
View file @
b14b8540
...
...
@@ -79,12 +79,6 @@ public class ArchitectureElementData {
}
public
int
getConstValue
()
{
assert
getElement
()
instanceof
ConstantSymbol
;
return
((
ConstantSymbol
)
getElement
()).
getExpression
().
getIntValue
().
get
();
}
public
List
<
Integer
>
getKernel
(){
return
getLayerSymbol
().
getIntTupleValue
(
AllPredefinedLayers
.
KERNEL_NAME
).
get
();
}
...
...
@@ -141,6 +135,22 @@ public class ArchitectureElementData {
return
getLayerSymbol
().
getIntValue
(
AllPredefinedLayers
.
SIZE_NAME
).
get
();
}
public
int
getRepeats
(){
return
getLayerSymbol
().
getIntValue
(
AllPredefinedLayers
.
REPEATS_NAME
).
get
();
}
public
int
getAxis
(){
return
getLayerSymbol
().
getIntValue
(
AllPredefinedLayers
.
AXIS_NAME
).
get
();
}
public
List
<
Integer
>
getAxes
(){
return
getLayerSymbol
().
getIntTupleValue
(
AllPredefinedLayers
.
AXES_NAME
).
get
();
}
public
List
<
Integer
>
getShape
(){
return
getLayerSymbol
().
getIntTupleValue
(
AllPredefinedLayers
.
SHAPE_NAME
).
get
();
}
public
int
getLayers
(){
return
getLayerSymbol
().
getIntValue
(
AllPredefinedLayers
.
LAYERS_NAME
).
get
();
}
...
...
@@ -161,10 +171,6 @@ public class ArchitectureElementData {
return
getLayerSymbol
().
getBooleanValue
(
AllPredefinedLayers
.
FLATTEN_PARAMETER_NAME
).
get
();
}
public
List
<
Integer
>
getShape
()
{
return
getLayerSymbol
().
getIntTupleValue
(
AllPredefinedLayers
.
SHAPE_NAME
).
get
();
}
@Nullable
public
String
getPoolType
(){
return
getLayerSymbol
().
getStringValue
(
AllPredefinedLayers
.
POOL_TYPE_NAME
).
get
();
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureSupportChecker.java
View file @
b14b8540
/* (c) https://github.com/MontiCore/monticore */
package
de.monticore.lang.monticar.cnnarch.generator
;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol
;
...
...
@@ -16,7 +17,7 @@ public abstract class ArchitectureSupportChecker {
// Overload functions returning always true to enable the features
protected
boolean
checkMultipleStreams
(
ArchitectureSymbol
architecture
)
{
if
(
architecture
.
get
Stream
s
().
size
()
!=
1
)
{
if
(
architecture
.
get
NetworkInstruction
s
().
size
()
!=
1
)
{
Log
.
error
(
"This cnn architecture has multiple instructions, "
+
"which is currently not supported by the code generator. "
,
architecture
.
getSourcePosition
());
...
...
@@ -66,7 +67,7 @@ public abstract class ArchitectureSupportChecker {
}
private
boolean
hasConstant
(
ArchitectureElementSymbol
element
)
{
ArchitectureElementSymbol
resolvedElement
=
element
.
getResolvedThis
().
get
();
ArchitectureElementSymbol
resolvedElement
=
(
ArchitectureElementSymbol
)
element
.
getResolvedThis
().
get
();
if
(
resolvedElement
instanceof
CompositeElementSymbol
)
{
List
<
ArchitectureElementSymbol
>
constructedElements
=
((
CompositeElementSymbol
)
resolvedElement
).
getElements
();
...
...
@@ -85,8 +86,8 @@ public abstract class ArchitectureSupportChecker {
}
protected
boolean
checkConstants
(
ArchitectureSymbol
architecture
)
{
for
(
SerialCompositeElementSymbol
stream
:
architecture
.
get
Stream
s
())
{
for
(
ArchitectureElementSymbol
element
:
stream
.
getElements
())
{
for
(
NetworkInstructionSymbol
networkInstruction
:
architecture
.
get
NetworkInstruction
s
())
{
for
(
ArchitectureElementSymbol
element
:
networkInstruction
.
getBody
()
.
getElements
())
{
if
(
hasConstant
(
element
))
{
Log
.
error
(
"This cnn architecture has a constant, which is currently not supported by the code generator."
,
architecture
.
getSourcePosition
());
...
...
@@ -109,8 +110,8 @@ public abstract class ArchitectureSupportChecker {
}
protected
boolean
checkOutputAsInput
(
ArchitectureSymbol
architecture
)
{
for
(
SerialCompositeElementSymbol
stream
:
architecture
.
get
Stream
s
())
{
for
(
ArchitectureElementSymbol
element
:
stream
.
getFirstAtomicElements
())
{
for
(
NetworkInstructionSymbol
networkInstruction
:
architecture
.
get
NetworkInstruction
s
())
{
for
(
ArchitectureElementSymbol
element
:
networkInstruction
.
getBody
()
.
getFirstAtomicElements
())
{
if
(
element
.
isOutput
())
{
Log
.
error
(
"This cnn architecture uses an output as an input, which is currently not supported by the code generator."
,
architecture
.
getSourcePosition
());
...
...
@@ -122,6 +123,18 @@ public abstract class ArchitectureSupportChecker {
return
true
;
}
protected
boolean
checkUnroll
(
ArchitectureSymbol
architecture
)
{
for
(
NetworkInstructionSymbol
networkInstruction
:
architecture
.
getNetworkInstructions
())
{
if
(
networkInstruction
.
isUnroll
())
{
Log
.
error
(
"This cnn architecture uses unrolls, which are currently not supported by the code generator."
,
architecture
.
getSourcePosition
());
return
false
;
}
}
return
true
;
}
public
boolean
check
(
ArchitectureSymbol
architecture
)
{
return
checkMultipleStreams
(
architecture
)
&&
checkMultipleInputs
(
architecture
)
...
...
@@ -129,6 +142,7 @@ public abstract class ArchitectureSupportChecker {
&&
checkMultiDimensionalOutput
(
architecture
)
&&
checkConstants
(
architecture
)
&&
checkLayerVariables
(
architecture
)
&&
checkOutputAsInput
(
architecture
);
&&
checkOutputAsInput
(
architecture
)
&&
checkUnroll
(
architecture
);
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchTemplateController.java
View file @
b14b8540
...
...
@@ -138,17 +138,50 @@ public abstract class CNNArchTemplateController {
for
(
VariableSymbol
element
:
getArchitecture
().
getInputs
()){
list
.
add
(
nameManager
.
getName
(
element
));
}
list
.
removeAll
(
Collections
.
singleton
(
null
));
return
list
;
}
public
List
<
String
>
getArchitectureOutputs
(){
List
<
String
>
list
=
new
ArrayList
<>();
for
(
VariableSymbol
element
:
getArchitecture
().
getOutputs
()){
list
.
add
(
nameManager
.
getName
(
element
));
if
(
nameManager
.
getName
(
element
)
!=
null
&&
!
list
.
contains
(
nameManager
.
getName
(
element
)))
{
list
.
add
(
nameManager
.
getName
(
element
));
}
}
return
list
;
}
public
List
<
VariableSymbol
>
getArchitectureInputSymbols
(){
Set
<
String
>
names
=
new
HashSet
();
List
<
VariableSymbol
>
noDuplicates
=
new
ArrayList
();
for
(
VariableSymbol
inputs
:
getArchitecture
().
getInputs
())
{
if
(
getName
(
inputs
)
!=
null
&&
!
names
.
contains
(
getName
(
inputs
)))
{
names
.
add
(
getName
(
inputs
));
noDuplicates
.
add
(
inputs
);
}
}
return
noDuplicates
;
}
public
List
<
VariableSymbol
>
getArchitectureOutputSymbols
(){
Set
<
String
>
names
=
new
HashSet
();
List
<
VariableSymbol
>
noDuplicates
=
new
ArrayList
();
for
(
VariableSymbol
output
:
getArchitecture
().
getOutputs
())
{
if
(
getName
(
output
)
!=
null
&&
!
names
.
contains
(
getName
(
output
)))
{
names
.
add
(
getName
(
output
));
noDuplicates
.
add
(
output
);
}
}
return
noDuplicates
;
}
public
String
getComponentName
(){
return
getArchitecture
().
getComponentName
();
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/generator/ConfigurationData.java
View file @
b14b8540
...
...
@@ -2,6 +2,7 @@
package
de.monticore.lang.monticar.cnnarch.generator
;
import
de.monticore.lang.monticar.cnntrain._symboltable.*
;
import
static
de
.
monticore
.
lang
.
monticar
.
cnntrain
.
helper
.
ConfigEntryNameConstants
.*;
import
java.util.ArrayList
;
import
java.util.HashMap
;
...
...
@@ -61,11 +62,8 @@ public class ConfigurationData {
return
getConfiguration
().
getEntry
(
"context"
).
getValue
().
toString
();
}
public
String
getEvalMetric
()
{
if
(!
getConfiguration
().
getEntryMap
().
containsKey
(
"eval_metric"
))
{
return
null
;
}
return
getConfiguration
().
getEntry
(
"eval_metric"
).
getValue
().
toString
();
public
Map
<
String
,
Object
>
getEvalMetric
()
{
return
getMultiParamEntry
(
EVAL_METRIC
,
"name"
);
}
public
String
getLossName
()
{
...
...
@@ -130,4 +128,47 @@ public class ConfigurationData {
}
return
mapToStrings
;
}
public
Boolean
getSaveAttentionImage
()
{
if
(!
getConfiguration
().
getEntryMap
().
containsKey
(
"save_attention_image"
))
{
return
null
;
}
return
(
Boolean
)
getConfiguration
().
getEntry
(
"save_attention_image"
).
getValue
().
getValue
();
}
public
Boolean
getUseTeacherForcing
()
{
if
(!
getConfiguration
().
getEntryMap
().
containsKey
(
"use_teacher_forcing"
))
{
return
null
;
}
return
(
Boolean
)
getConfiguration
().
getEntry
(
"use_teacher_forcing"
).
getValue
().
getValue
();
}
protected
Map
<
String
,
Object
>
getMultiParamEntry
(
final
String
key
,
final
String
valueName
)
{
if
(!
configurationContainsKey
(
key
))
{
return
null
;
}
Map
<
String
,
Object
>
resultView
=
new
HashMap
<>();
ValueSymbol
value
=
this
.
getConfiguration
().
getEntryMap
().
get
(
key
).
getValue
();
if
(
value
instanceof
MultiParamValueSymbol
)
{
MultiParamValueSymbol
multiParamValue
=
(
MultiParamValueSymbol
)
value
;
resultView
.
put
(
valueName
,
multiParamValue
.
getValue
());
resultView
.
putAll
(
multiParamValue
.
getParameters
());
}
else
{
resultView
.
put
(
valueName
,
value
.
getValue
());
}
return
resultView
;
}
protected
Boolean
configurationContainsKey
(
final
String
key
)
{
return
this
.
getConfiguration
().
getEntryMap
().
containsKey
(
key
);
}
protected
Object
retrieveConfigurationEntryValueByKey
(
final
String
key
)
{
return
this
.
getConfiguration
().
getEntry
(
key
).
getValue
().
getValue
();
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerNameCreator.java
View file @
b14b8540
...
...
@@ -11,17 +11,21 @@ import java.util.*;
public
class
LayerNameCreator
{
private
Map
<
ArchitectureElementSymbol
,
String
>
elementToName
=
new
HashMap
<>();
private
Map
<
String
,
ArchitectureElementSymbol
>
nameToElement
=
new
Hash
Map
<>();
private
Set
<
String
>
names
=
new
Hash
Set
<>();
public
LayerNameCreator
(
ArchitectureSymbol
architecture
)
{
int
stage
=
1
;
for
(
SerialCompositeElementSymbol
stream
:
architecture
.
getStreams
())
{
stage
=
name
(
stream
,
stage
,
new
ArrayList
<>());
}
}
for
(
NetworkInstructionSymbol
networkInstruction
:
architecture
.
getNetworkInstructions
())
{
stage
=
name
(
networkInstruction
.
getBody
(),
stage
,
new
ArrayList
<>());
if
(
networkInstruction
.
isUnroll
())
{
UnrollInstructionSymbol
unroll
=
(
UnrollInstructionSymbol
)
networkInstruction
;
public
ArchitectureElementSymbol
getArchitectureElement
(
String
name
){
return
nameToElement
.
get
(
name
);
for
(
SerialCompositeElementSymbol
body
:
unroll
.
getResolvedBodies
())
{
stage
=
name
(
body
,
stage
,
new
ArrayList
<>());
}
}
}
}
public
String
getName
(
ArchitectureElementSymbol
architectureElement
){
...
...
@@ -31,17 +35,17 @@ public class LayerNameCreator {
protected
int
name
(
ArchitectureElementSymbol
architectureElement
,
int
stage
,
List
<
Integer
>
streamIndices
){
if
(
architectureElement
instanceof
SerialCompositeElementSymbol
)
{
return
nameSerialComposite
((
SerialCompositeElementSymbol
)
architectureElement
,
stage
,
streamIndices
);
}
else
if
(
architectureElement
instanceof
ParallelCompositeElementSymbol
){
}
else
if
(
architectureElement
instanceof
ParallelCompositeElementSymbol
)
{
return
nameParallelComposite
((
ParallelCompositeElementSymbol
)
architectureElement
,
stage
,
streamIndices
);
}
else
{
if
(
architectureElement
.
isAtomic
()){
}
else
{
if
(
architectureElement
.
isAtomic
())
{
if
(
architectureElement
.
getMaxSerialLength
().
get
()
>
0
){
return
add
(
architectureElement
,
stage
,
streamIndices
);
}
else
{
return
stage
;
}
}
else
{
ArchitectureElementSymbol
resolvedElement
=
architectureElement
.
getResolvedThis
().
get
();
ArchitectureElementSymbol
resolvedElement
=
(
ArchitectureElementSymbol
)
architectureElement
.
getResolvedThis
().
get
();
return
name
(
resolvedElement
,
stage
,
streamIndices
);
}
}
...
...
@@ -75,24 +79,15 @@ public class LayerNameCreator {
if
(!
elementToName
.
containsKey
(
architectureElement
))
{
String
name
=
createName
(
architectureElement
,
endStage
,
streamIndices
);
while
(
nameToElement
.
containsKey
(
name
))
{
endStage
++;
name
=
createName
(
architectureElement
,
endStage
,
streamIndices
);
if
(!(
architectureElement
instanceof
VariableSymbol
))
{
while
(
names
.
contains
(
name
))
{
endStage
++;
name
=
createName
(
architectureElement
,
endStage
,
streamIndices
);
}
}
elementToName
.
put
(
architectureElement
,
name
);
boolean
isLayerVariable
=
false
;
if
(
architectureElement
instanceof
VariableSymbol
)
{
isLayerVariable
=
((
VariableSymbol
)
architectureElement
).
getType
()
==
VariableSymbol
.
Type
.
LAYER
;
}
// Do not map names of layer variables to their respective element since the names are not unique
// for now the name to element mapping is not used anywhere so it doesn't matter
if
(!
isLayerVariable
)
{
nameToElement
.
put
(
name
,
architectureElement
);
}
names
.
add
(
name
);
}
return
endStage
;
}
...
...
@@ -101,23 +96,21 @@ public class LayerNameCreator {
if
(
architectureElement
instanceof
VariableSymbol
)
{
VariableSymbol
element
=
(
VariableSymbol
)
architectureElement
;
String
name
=
createBaseName
(
architectureElement
);
String
name
=
createBaseName
(
architectureElement
)
+
"_"
;
if
(
element
.
getType
()
==
VariableSymbol
.
Type
.
IO
)
{
if
(
element
.
getArrayAccess
().
isPresent
()){
int
arrayAccess
=
element
.
getArrayAccess
().
get
().
getIntValue
().
get
();
name
=
name
+
"_"
+
arrayAccess
+
"_"
;
}
else
{
name
=
name
+
"_"
;
}
}
else
if
(
element
.
getType
()
==
VariableSymbol
.
Type
.
LAYER
)
{
if
(
element
.
getType
()
==
VariableSymbol
.
Type
.
LAYER
)
{
if
(
element
.
getMember
()
==
VariableSymbol
.
Member
.
STATE
)
{
name
=
name
+
"
_
state_"
;
name
=
name
+
"state_"
;
}
else
{
name
=
name
+
"
_
output_"
;
name
=
name
+
"output_"
;
}
}
if
(
element
.
getArrayAccess
().
isPresent
()){
int
arrayAccess
=
element
.
getArrayAccess
().
get
().
getIntValue
().
get
();
name
=
name
+
arrayAccess
+
"_"
;
}
return
name
;
}
else
{
return
createBaseName
(
architectureElement
)
+
stage
+
createStreamPostfix
(
streamIndices
)
+
"_"
;
...
...
@@ -153,4 +146,3 @@ public class LayerNameCreator {
return
stringBuilder
.
toString
();
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerSupportChecker.java
View file @
b14b8540
/* (c) https://github.com/MontiCore/monticore */
package
de.monticore.lang.monticar.cnnarch.generator
;
import
de.monticore.lang.monticar.cnnarch.
predefined.AllPredefinedLayers
;
import
de.monticore.lang.monticar.cnnarch.
_symboltable.*
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol
;
import
de.se_rwth.commons.logging.Log
;
...
...
@@ -19,11 +18,10 @@ public abstract class LayerSupportChecker {
protected
List
<
String
>
supportedLayerList
=
new
ArrayList
<>();
private
boolean
isSupportedLayer
(
ArchitectureElementSymbol
element
){
ArchitectureElementSymbol
resolvedElement
=
element
.
getResolvedThis
().
get
();
List
<
ArchitectureElementSymbol
>
constructLayerElemList
;
ArchitectureElementSymbol
resolvedElement
=
(
ArchitectureElementSymbol
)
element
.
getResolvedThis
().
get
();
if
(
resolvedElement
instanceof
CompositeElementSymbol
)
{
constructLayerElemList
=
((
CompositeElementSymbol
)
resolvedElement
).
getElements
();
List
<
ArchitectureElementSymbol
>
constructLayerElemList
=
((
CompositeElementSymbol
)
resolvedElement
).
getElements
();
for
(
ArchitectureElementSymbol
constructedLayerElement
:
constructLayerElemList
)
{
if
(!
isSupportedLayer
(
constructedLayerElement
))
{
return
false
;
...
...
@@ -63,8 +61,8 @@ public abstract class LayerSupportChecker {
}
public
boolean
check
(
ArchitectureSymbol
architecture
)
{
for
(
CompositeElementSymbol
stream
:
architecture
.
get
Stream
s
())
{
for
(
ArchitectureElementSymbol
element
:
stream
.
getElements
())
{
for
(
NetworkInstructionSymbol
networkInstructions
:
architecture
.
get
NetworkInstruction
s
())
{
for
(
ArchitectureElementSymbol
element
:
networkInstructions
.
getBody
()
.
getElements
())
{
if
(!
isSupportedLayer
(
element
))
{
return
false
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment