Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
languages
CNNArchLang
Commits
cbbb702c
Commit
cbbb702c
authored
Oct 06, 2019
by
Christian Fuß
Browse files
added SwapAxes layer. Adjusted FullyConnected layer to work with RNN states.
parent
13b87e01
Pipeline
#191051
failed with stages
in 37 seconds
Changes
6
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java
View file @
cbbb702c
...
...
@@ -63,7 +63,11 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
public
boolean
isTrainable
(
VariableSymbol
.
Member
member
)
{
return
true
;
if
(
member
==
VariableSymbol
.
Member
.
STATE
||
member
==
VariableSymbol
.
Member
.
OUTPUT
){
return
false
;
}
else
{
return
true
;
}
}
/**
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java
View file @
cbbb702c
...
...
@@ -58,6 +58,7 @@ public class AllPredefinedLayers {
public
static
final
String
REDUCE_SUM_NAME
=
"ReduceSum"
;
public
static
final
String
EXPAND_DIMS_NAME
=
"ExpandDims"
;
public
static
final
String
MULTIPLY_NAME
=
"Multiply"
;
public
static
final
String
SWAPAXES_NAME
=
"SwapAxes"
;
//predefined argument names
public
static
final
String
KERNEL_NAME
=
"kernel"
;
...
...
@@ -86,6 +87,7 @@ public class AllPredefinedLayers {
public
static
final
String
WIDTH_NAME
=
"width"
;
public
static
final
String
REPEATS_NAME
=
"n"
;
public
static
final
String
AXIS_NAME
=
"axis"
;
public
static
final
String
AXES_NAME
=
"axes"
;
//possible String values
public
static
final
String
PADDING_VALID
=
"valid"
;
...
...
@@ -124,7 +126,8 @@ public class AllPredefinedLayers {
Repeat
.
create
(),
ReduceSum
.
create
(),
ExpandDims
.
create
(),
Multiply
.
create
());
Multiply
.
create
(),
SwapAxes
.
create
());
}
public
static
List
<
UnrollDeclarationSymbol
>
createUnrollList
(){
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Concatenate.java
View file @
cbbb702c
...
...
@@ -38,37 +38,67 @@ public class Concatenate extends PredefinedLayerDeclaration {
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
int
height
=
layer
.
getInputTypes
().
get
(
0
).
getHeight
()
;
int
width
=
layer
.
getInputTypes
().
get
(
0
).
getWidth
()
;
int
height
=
0
;
int
width
=
0
;
int
channels
=
0
;
for
(
ArchTypeSymbol
inputShape
:
layer
.
getInputTypes
())
{
channels
+=
inputShape
.
getChannels
();
}
int
dim
=
layer
.
getIntValue
(
AllPredefinedLayers
.
DIM_NAME
).
get
();
List
<
String
>
range
=
computeStartAndEndValue
(
layer
.
getInputTypes
(),
(
x
,
y
)
->
x
.
isLessThan
(
y
)
?
x
:
y
,
(
x
,
y
)
->
x
.
isLessThan
(
y
)
?
y
:
x
);
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
channels
)
.
height
(
height
)
.
width
(
width
)
.
elementType
(
range
.
get
(
0
),
range
.
get
(
1
))
.
build
());
if
(
dim
==
0
){
for
(
ArchTypeSymbol
inputShape
:
layer
.
getInputTypes
())
{
channels
+=
inputShape
.
getChannels
();
}
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
channels
)
.
height
(
layer
.
getInputTypes
().
get
(
0
).
getHeight
())
.
width
(
layer
.
getInputTypes
().
get
(
0
).
getWidth
())
.
elementType
(
range
.
get
(
0
),
range
.
get
(
1
))
.
build
());
}
else
if
(
dim
==
1
){
for
(
ArchTypeSymbol
inputShape
:
layer
.
getInputTypes
())
{
height
+=
inputShape
.
getHeight
();
}
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
layer
.
getInputTypes
().
get
(
0
).
getChannels
())
.
height
(
height
)
.
width
(
layer
.
getInputTypes
().
get
(
0
).
getWidth
())
.
elementType
(
range
.
get
(
0
),
range
.
get
(
1
))
.
build
());
}
else
if
(
dim
==
2
){
for
(
ArchTypeSymbol
inputShape
:
layer
.
getInputTypes
())
{
width
+=
inputShape
.
getWidth
();
}
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
layer
.
getInputTypes
().
get
(
0
).
getChannels
())
.
height
(
layer
.
getInputTypes
().
get
(
0
).
getHeight
())
.
width
(
width
)
.
elementType
(
range
.
get
(
0
),
range
.
get
(
1
))
.
build
());
}
else
{
return
new
ArrayList
<>();
}
}
@Override
public
void
checkInput
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
if
(!
inputTypes
.
isEmpty
())
{
List
<
Integer
>
channelList
=
new
ArrayList
<>();
List
<
Integer
>
heightList
=
new
ArrayList
<>();
List
<
Integer
>
widthList
=
new
ArrayList
<>();
for
(
ArchTypeSymbol
shape
:
inputTypes
){
heightList
.
add
(
shape
.
getHeight
());
widthList
.
add
(
shape
.
getWidth
());
channelList
.
add
(
shape
.
getChannels
());
}
int
countEqualcHannels
=
(
int
)
channelList
.
stream
().
distinct
().
count
();
int
countEqualHeights
=
(
int
)
heightList
.
stream
().
distinct
().
count
();
int
countEqualWidths
=
(
int
)
widthList
.
stream
().
distinct
().
count
();
if
(
countEqualHeights
!=
1
||
countEqualWidths
!=
1
){
if
(
countEqualHeights
!=
1
&&
countEqualWidths
!=
1
&&
countEqualcHannels
!=
1
){
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_ELEMENT_INPUT_SHAPE
+
" Invalid layer input. "
+
"Concatenation of inputs with different resolutions is not possible. "
+
"Input channels: "
+
Joiners
.
COMMA
.
join
(
channelList
)
+
". "
+
"Input heights: "
+
Joiners
.
COMMA
.
join
(
heightList
)
+
". "
+
"Input widths: "
+
Joiners
.
COMMA
.
join
(
widthList
)
+
". "
,
layer
.
getSourcePosition
());
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/FullyConnected.java
View file @
cbbb702c
...
...
@@ -59,12 +59,21 @@ public class FullyConnected extends PredefinedLayerDeclaration {
.
build
());
}
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
inputType
.
getChannels
())
.
height
(
units
)
.
width
(
1
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
// if input is an RNN state or output. Can be used to store states in layer variables
if
(
layer
.
getInputElement
().
get
()
instanceof
VariableSymbol
&&
(((
VariableSymbol
)(
layer
.
getInputElement
().
get
())).
getMember
()
==
VariableSymbol
.
Member
.
STATE
||
((
VariableSymbol
)((
ArchitectureElementSymbol
)
layer
.
getInputElement
().
get
())).
getMember
()
==
VariableSymbol
.
Member
.
OUTPUT
)){
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
inputType
.
getChannels
())
.
height
(
units
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
}
else
{
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
inputType
.
getChannels
())
.
height
(
units
)
.
width
(
1
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
}
}
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Split.java
View file @
cbbb702c
...
...
@@ -34,6 +34,11 @@ public class Split extends PredefinedLayerDeclaration {
super
(
AllPredefinedLayers
.
SPLIT_NAME
);
}
@Override
public
boolean
isTrainable
()
{
return
false
;
}
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
ArchTypeSymbol
inputShape
=
layer
.
getInputTypes
().
get
(
0
);
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/SwapAxes.java
0 → 100644
View file @
cbbb702c
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch.predefined
;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
de.monticore.lang.monticar.cnnarch.helper.ErrorCodes
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.ArrayList
;
import
java.util.Arrays
;
import
java.util.Collections
;
import
java.util.List
;
public
class
SwapAxes
extends
PredefinedLayerDeclaration
{
private
SwapAxes
()
{
super
(
AllPredefinedLayers
.
SWAPAXES_NAME
);
}
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
int
firstAxis
=
layer
.
getIntTupleValue
(
AllPredefinedLayers
.
AXES_NAME
).
get
().
get
(
0
);
int
secondAxis
=
layer
.
getIntTupleValue
(
AllPredefinedLayers
.
AXES_NAME
).
get
().
get
(
1
);
if
((
firstAxis
==
0
&&
secondAxis
==
1
)
||
(
firstAxis
==
1
&&
secondAxis
==
0
)){
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
layer
.
getInputTypes
().
get
(
0
).
getHeight
())
.
height
(
layer
.
getInputTypes
().
get
(
0
).
getChannels
())
.
width
(
layer
.
getInputTypes
().
get
(
0
).
getWidth
())
.
elementType
(
layer
.
getInputTypes
().
get
(
0
).
getDomain
())
.
build
());
}
else
if
((
firstAxis
==
0
&&
secondAxis
==
2
)
||
(
firstAxis
==
2
&&
secondAxis
==
0
)){
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
layer
.
getInputTypes
().
get
(
0
).
getWidth
())
.
height
(
layer
.
getInputTypes
().
get
(
0
).
getHeight
())
.
width
(
layer
.
getInputTypes
().
get
(
0
).
getChannels
())
.
elementType
(
layer
.
getInputTypes
().
get
(
0
).
getDomain
())
.
build
());
}
else
if
((
firstAxis
==
1
&&
secondAxis
==
2
)
||
(
firstAxis
==
2
&&
secondAxis
==
1
)){
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
layer
.
getInputTypes
().
get
(
0
).
getChannels
())
.
height
(
layer
.
getInputTypes
().
get
(
0
).
getWidth
())
.
width
(
layer
.
getInputTypes
().
get
(
0
).
getHeight
())
.
elementType
(
layer
.
getInputTypes
().
get
(
0
).
getDomain
())
.
build
());
}
else
{
if
((
firstAxis
<
0
||
firstAxis
>
2
||
secondAxis
<
0
||
secondAxis
>
2
)){
Log
.
error
(
"0"
+
ErrorCodes
.
ILLEGAL_PARAMETER_VALUE
+
" Illegal value for parameter axes. Values must be between 0 and 2"
,
layer
.
getSourcePosition
());
}
return
new
ArrayList
<>();
}
}
@Override
public
void
checkInput
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
errorIfInputSizeIsNotOne
(
inputTypes
,
layer
);
}
public
static
SwapAxes
create
(){
SwapAxes
declaration
=
new
SwapAxes
();
List
<
ParameterSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
AXES_NAME
)
.
constraints
(
Constraints
.
INTEGER_TUPLE
)
.
build
()));
declaration
.
setParameters
(
parameters
);
return
declaration
;
}
}
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment