Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
languages
CNNArchLang
Commits
09af254b
Commit
09af254b
authored
Oct 24, 2019
by
Sebastian Nickels
Browse files
Added Squeeze layer, updated RNN models, cleaned up some layers
parent
cbbb702c
Changes
13
Hide whitespace changes
Inline
Side-by-side
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/Constraints.java
View file @
09af254b
...
...
@@ -184,6 +184,70 @@ public enum Constraints {
return
AllPredefinedLayers
.
POOL_MAX
+
" or "
+
AllPredefinedLayers
.
POOL_AVG
;
}
},
NULLABLE_AXIS
{
@Override
public
boolean
isValid
(
ArchSimpleExpressionSymbol
exp
)
{
if
(
exp
.
getIntValue
().
isPresent
()){
int
intValue
=
exp
.
getIntValue
().
get
();
return
intValue
>=
-
1
&&
intValue
<=
2
;
// -1 is null
}
return
false
;
}
@Override
public
String
msgString
()
{
return
"an axis between 0 and 2 or -1"
;
}
},
NULLABLE_AXIS_WITHOUT_2
{
@Override
public
boolean
isValid
(
ArchSimpleExpressionSymbol
exp
)
{
if
(
exp
.
getIntValue
().
isPresent
()){
int
intValue
=
exp
.
getIntValue
().
get
();
return
intValue
>=
-
1
&&
intValue
<=
2
;
// -1 is null
}
return
false
;
}
@Override
public
String
msgString
()
{
return
"an axis between 0 and 1 or -1"
;
}
},
AXIS
{
@Override
public
boolean
isValid
(
ArchSimpleExpressionSymbol
exp
)
{
if
(
exp
.
getIntValue
().
isPresent
()){
int
intValue
=
exp
.
getIntValue
().
get
();
return
intValue
>=
0
&&
intValue
<=
2
;
}
return
false
;
}
@Override
public
String
msgString
()
{
return
"an axis between 0 and 2"
;
}
},
AXIS_WITHOUT_2
{
@Override
public
boolean
isValid
(
ArchSimpleExpressionSymbol
exp
)
{
if
(
exp
.
getIntValue
().
isPresent
()){
int
intValue
=
exp
.
getIntValue
().
get
();
return
intValue
>=
0
&&
intValue
<=
1
;
}
return
false
;
}
@Override
public
String
msgString
()
{
return
"an axis between 0 and 1"
;
}
};
protected
abstract
boolean
isValid
(
ArchSimpleExpressionSymbol
exp
);
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java
View file @
09af254b
...
...
@@ -121,20 +121,6 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
}
protected
void
errorIfAxisNotFeasible
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
){
if
((!
layer
.
getStringValue
(
AllPredefinedLayers
.
AXIS_NAME
).
isPresent
()
&&
layer
.
getIntValue
(
AllPredefinedLayers
.
AXIS_NAME
).
get
()
>
1
)){
Log
.
error
(
"0"
+
ErrorCodes
.
ILLEGAL_PARAMETER_VALUE
+
" Illegal value for parameter axis. Value must be None, 0 or 1"
,
layer
.
getSourcePosition
());
}
}
protected
void
errorIfDimNotFeasible
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
){
if
(
layer
.
getIntValue
(
AllPredefinedLayers
.
DIM_NAME
).
get
()
>
1
){
Log
.
error
(
"0"
+
ErrorCodes
.
ILLEGAL_PARAMETER_VALUE
+
" Illegal value for parameter dim. Value must be 0 or 1"
,
layer
.
getSourcePosition
());
}
}
protected
void
errorIfInputNotFeasibleForDotProduct
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
){
if
(!(
layer
.
getInputTypes
().
get
(
1
).
getHeight
()
==
layer
.
getInputTypes
().
get
(
0
).
getWidth
())){
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_ELEMENT_INPUT_SHAPE
+
" Invalid layer input. Dot Product cannot be applied to input 1 with height "
+
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java
View file @
09af254b
...
...
@@ -55,6 +55,7 @@ public class AllPredefinedLayers {
public
static
final
String
ARG_MAX_NAME
=
"ArgMax"
;
public
static
final
String
DOT_NAME
=
"Dot"
;
public
static
final
String
REPEAT_NAME
=
"Repeat"
;
public
static
final
String
SQUEEZE_NAME
=
"Squeeze"
;
public
static
final
String
REDUCE_SUM_NAME
=
"ReduceSum"
;
public
static
final
String
EXPAND_DIMS_NAME
=
"ExpandDims"
;
public
static
final
String
MULTIPLY_NAME
=
"Multiply"
;
...
...
@@ -80,7 +81,6 @@ public class AllPredefinedLayers {
public
static
final
String
LAYERS_NAME
=
"layers"
;
public
static
final
String
INPUT_DIM_NAME
=
"input_dim"
;
public
static
final
String
OUTPUT_DIM_NAME
=
"output_dim"
;
public
static
final
String
DIM_NAME
=
"dim"
;
public
static
final
String
BIDIRECTIONAL_NAME
=
"bidirectional"
;
public
static
final
String
FLATTEN_PARAMETER_NAME
=
"flatten"
;
public
static
final
String
MAX_LENGTH_NAME
=
"max_length"
;
...
...
@@ -124,6 +124,7 @@ public class AllPredefinedLayers {
ArgMax
.
create
(),
Dot
.
create
(),
Repeat
.
create
(),
Squeeze
.
create
(),
ReduceSum
.
create
(),
ExpandDims
.
create
(),
Multiply
.
create
(),
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/BaseRNN.java
View file @
09af254b
...
...
@@ -29,6 +29,8 @@ import java.util.List;
abstract
public
class
BaseRNN
extends
PredefinedLayerDeclaration
{
protected
int
numberOfStates
=
1
;
public
BaseRNN
(
String
name
)
{
super
(
name
);
}
...
...
@@ -47,8 +49,9 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
int
layers
=
layer
.
getIntValue
(
AllPredefinedLayers
.
LAYERS_NAME
).
get
();
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
bidirectional
?
2
*
layers
:
layers
)
.
height
(
units
)
.
channels
(
numberOfStates
)
.
height
(
bidirectional
?
2
*
layers
:
layers
)
.
width
(
units
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
}
...
...
@@ -69,9 +72,9 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
if
(
member
==
VariableSymbol
.
Member
.
STATE
)
{
errorIfInputSizeIsNotOne
(
inputTypes
,
layer
);
errorIfInputChannelSizeIsInvalid
(
inputTypes
,
layer
,
bidirectional
?
2
*
layers
:
layer
s
);
errorIfInputHeightIsInvalid
(
inputTypes
,
layer
,
unit
s
);
errorIfInputWidthIsInvalid
(
inputTypes
,
layer
,
1
);
errorIfInputChannelSizeIsInvalid
(
inputTypes
,
layer
,
numberOfState
s
);
errorIfInputHeightIsInvalid
(
inputTypes
,
layer
,
bidirectional
?
2
*
layers
:
layer
s
);
errorIfInputWidthIsInvalid
(
inputTypes
,
layer
,
units
);
}
else
{
errorIfInputSizeIsNotOne
(
inputTypes
,
layer
);
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Concatenate.java
View file @
09af254b
...
...
@@ -38,74 +38,61 @@ public class Concatenate extends PredefinedLayerDeclaration {
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
int
height
=
0
;
int
width
=
0
;
int
channels
=
0
;
int
channels
=
layer
.
getInputTypes
().
get
(
0
).
getChannels
()
;
int
height
=
layer
.
getInputTypes
().
get
(
0
).
getHeight
()
;
int
width
=
layer
.
getInputTypes
().
get
(
0
).
getWidth
()
;
int
dim
=
layer
.
getIntValue
(
AllPredefinedLayers
.
DIM
_NAME
).
get
();
int
axis
=
layer
.
getIntValue
(
AllPredefinedLayers
.
AXIS
_NAME
).
get
();
List
<
String
>
range
=
computeStartAndEndValue
(
layer
.
getInputTypes
(),
(
x
,
y
)
->
x
.
isLessThan
(
y
)
?
x
:
y
,
(
x
,
y
)
->
x
.
isLessThan
(
y
)
?
y
:
x
);
if
(
dim
==
0
){
for
(
ArchTypeSymbol
inputShape
:
layer
.
getInputTypes
())
{
List
<
ArchTypeSymbol
>
types
=
layer
.
getInputTypes
();
types
.
remove
(
0
);
if
(
axis
==
0
)
{
for
(
ArchTypeSymbol
inputShape
:
types
)
{
channels
+=
inputShape
.
getChannels
();
}
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
channels
)
.
height
(
layer
.
getInputTypes
().
get
(
0
).
getHeight
())
.
width
(
layer
.
getInputTypes
().
get
(
0
).
getWidth
())
.
elementType
(
range
.
get
(
0
),
range
.
get
(
1
))
.
build
());
}
else
if
(
dim
==
1
){
for
(
ArchTypeSymbol
inputShape
:
layer
.
getInputTypes
())
{
}
else
if
(
axis
==
1
)
{
for
(
ArchTypeSymbol
inputShape
:
types
)
{
height
+=
inputShape
.
getHeight
();
}
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
layer
.
getInputTypes
().
get
(
0
).
getChannels
())
.
height
(
height
)
.
width
(
layer
.
getInputTypes
().
get
(
0
).
getWidth
())
.
elementType
(
range
.
get
(
0
),
range
.
get
(
1
))
.
build
());
}
else
if
(
dim
==
2
){
for
(
ArchTypeSymbol
inputShape
:
layer
.
getInputTypes
())
{
}
else
{
for
(
ArchTypeSymbol
inputShape
:
types
)
{
width
+=
inputShape
.
getWidth
();
}
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
layer
.
getInputTypes
().
get
(
0
).
getChannels
())
.
height
(
layer
.
getInputTypes
().
get
(
0
).
getHeight
())
.
width
(
width
)
.
elementType
(
range
.
get
(
0
),
range
.
get
(
1
))
.
build
());
}
else
{
return
new
ArrayList
<>();
}
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
channels
)
.
height
(
height
)
.
width
(
width
)
.
elementType
(
range
.
get
(
0
),
range
.
get
(
1
))
.
build
());
}
@Override
public
void
checkInput
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
if
(!
inputTypes
.
isEmpty
())
{
List
<
Integer
>
channelList
=
new
ArrayList
<>();
List
<
Integer
>
heightList
=
new
ArrayList
<>();
List
<
Integer
>
widthList
=
new
ArrayList
<>();
for
(
ArchTypeSymbol
shape
:
inputTypes
){
heightList
.
add
(
shape
.
getHeight
());
widthList
.
add
(
shape
.
getWidth
());
channelList
.
add
(
shape
.
getChannels
());
}
int
countEqualcHannels
=
(
int
)
channelList
.
stream
().
distinct
().
count
();
int
countEqualHeights
=
(
int
)
heightList
.
stream
().
distinct
().
count
();
int
countEqualWidths
=
(
int
)
widthList
.
stream
().
distinct
().
count
();
if
(
countEqualHeights
!=
1
&&
countEqualWidths
!=
1
&&
countEqualcHannels
!=
1
){
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_ELEMENT_INPUT_SHAPE
+
" Invalid layer input. "
+
"Concatenation of inputs with different resolutions is not possible. "
+
"Input channels: "
+
Joiners
.
COMMA
.
join
(
channelList
)
+
". "
+
"Input heights: "
+
Joiners
.
COMMA
.
join
(
heightList
)
+
". "
+
"Input widths: "
+
Joiners
.
COMMA
.
join
(
widthList
)
+
". "
,
layer
.
getSourcePosition
());
}
errorIfInputIsEmpty
(
inputTypes
,
layer
);
List
<
Integer
>
channelList
=
new
ArrayList
<>();
List
<
Integer
>
heightList
=
new
ArrayList
<>();
List
<
Integer
>
widthList
=
new
ArrayList
<>();
for
(
ArchTypeSymbol
shape
:
inputTypes
){
heightList
.
add
(
shape
.
getHeight
());
widthList
.
add
(
shape
.
getWidth
());
channelList
.
add
(
shape
.
getChannels
());
}
else
{
errorIfInputIsEmpty
(
inputTypes
,
layer
);
int
countEqualChannels
=
(
int
)
channelList
.
stream
().
distinct
().
count
();
int
countEqualHeights
=
(
int
)
heightList
.
stream
().
distinct
().
count
();
int
countEqualWidths
=
(
int
)
widthList
.
stream
().
distinct
().
count
();
if
(
countEqualHeights
!=
1
&&
countEqualWidths
!=
1
&&
countEqualChannels
!=
1
){
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_ELEMENT_INPUT_SHAPE
+
" Invalid layer input. "
+
"Concatenation of inputs with different resolutions is not possible. "
+
"Input channels: "
+
Joiners
.
COMMA
.
join
(
channelList
)
+
". "
+
"Input heights: "
+
Joiners
.
COMMA
.
join
(
heightList
)
+
". "
+
"Input widths: "
+
Joiners
.
COMMA
.
join
(
widthList
)
+
". "
,
layer
.
getSourcePosition
());
}
}
...
...
@@ -113,9 +100,9 @@ public class Concatenate extends PredefinedLayerDeclaration {
Concatenate
declaration
=
new
Concatenate
();
List
<
ParameterSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
DIM
_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
NON_NEGATIVE
)
.
defaultValue
(
1
)
.
name
(
AllPredefinedLayers
.
AXIS
_NAME
)
.
constraints
(
Constraints
.
AXIS
)
.
defaultValue
(
0
)
.
build
()));
declaration
.
setParameters
(
parameters
);
return
declaration
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/ExpandDims.java
View file @
09af254b
...
...
@@ -37,17 +37,17 @@ public class ExpandDims extends PredefinedLayerDeclaration {
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
int
axis
=
layer
.
getIntValue
(
AllPredefinedLayers
.
AXIS_NAME
).
get
();
int
dim
=
layer
.
getIntValue
(
AllPredefinedLayers
.
DIM_NAME
).
get
();
int
channels
=
layer
.
getInputTypes
().
get
(
0
).
getChannels
();
int
height
=
layer
.
getInputTypes
().
get
(
0
).
getHeight
();
int
width
=
layer
.
getInputTypes
().
get
(
0
).
getWidth
();
if
(
dim
==
0
)
{
if
(
axis
==
0
)
{
width
=
height
;
height
=
channels
;
channels
=
1
;
}
else
if
(
dim
==
1
)
{
}
else
if
(
axis
==
1
)
{
width
=
height
;
height
=
1
;
}
...
...
@@ -56,22 +56,22 @@ public class ExpandDims extends PredefinedLayerDeclaration {
.
channels
(
channels
)
.
height
(
height
)
.
width
(
width
)
.
elementType
(
"-oo"
,
"oo"
)
.
elementType
(
layer
.
getInputTypes
().
get
(
0
).
getDomain
()
)
.
build
());
}
@Override
public
void
checkInput
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
errorIfInputIsEmpty
(
inputTypes
,
layer
);
errorIf
DimNotFeasible
(
inputTypes
,
layer
);
errorIf
InputWidthIsInvalid
(
inputTypes
,
layer
,
1
);
}
public
static
ExpandDims
create
(){
ExpandDims
declaration
=
new
ExpandDims
();
List
<
ParameterSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
DIM
_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
NON_NEGATIVE
)
.
name
(
AllPredefinedLayers
.
AXIS
_NAME
)
.
constraints
(
Constraints
.
AXIS_WITHOUT_2
)
.
build
()));
declaration
.
setParameters
(
parameters
);
return
declaration
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/LSTM.java
View file @
09af254b
...
...
@@ -24,6 +24,8 @@ public class LSTM extends BaseRNN {
private
LSTM
()
{
super
(
AllPredefinedLayers
.
LSTM_NAME
);
numberOfStates
=
2
;
}
public
static
LSTM
create
()
{
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/ReduceSum.java
View file @
09af254b
...
...
@@ -23,6 +23,7 @@ package de.monticore.lang.monticar.cnnarch.predefined;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
de.monticore.lang.monticar.cnnarch.helper.ErrorCodes
;
import
de.se_rwth.commons.logging.Log
;
import
org.fusesource.jansi.internal.Kernel32
;
import
java.util.ArrayList
;
import
java.util.Arrays
;
...
...
@@ -37,44 +38,33 @@ public class ReduceSum extends PredefinedLayerDeclaration {
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
int
axis
;
boolean
axisIsNone
=
layer
.
getStringValue
(
AllPredefinedLayers
.
AXIS_NAME
).
isPresent
();
int
axis
=
layer
.
getIntValue
(
AllPredefinedLayers
.
AXIS_NAME
).
get
();
int
channels
=
layer
.
getInputTypes
().
get
(
0
).
getChannels
();
int
height
=
layer
.
getInputTypes
().
get
(
0
).
getHeight
();
int
width
=
layer
.
getInputTypes
().
get
(
0
).
getWidth
();
if
(
axisIsNone
)
{
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
1
)
.
height
(
1
)
.
width
(
1
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
if
(
axis
==
0
)
{
height
=
1
;
}
else
if
(
axis
==
1
)
{
width
=
1
;
}
else
{
axis
=
layer
.
getIntValue
(
AllPredefinedLayers
.
AXIS_NAME
).
get
();
if
(
axis
==
0
)
{
height
=
1
;
}
else
{
width
=
1
;
}
channels
=
1
;
height
=
1
;
width
=
1
;
}
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
channels
)
.
height
(
height
)
.
width
(
width
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
}
}
@Override
public
void
checkInput
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
errorIfInputIsEmpty
(
inputTypes
,
layer
);
errorIfAxisNotFeasible
(
inputTypes
,
layer
);
}
public
static
ReduceSum
create
(){
...
...
@@ -82,6 +72,8 @@ public class ReduceSum extends PredefinedLayerDeclaration {
List
<
ParameterSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
AXIS_NAME
)
.
constraints
(
Constraints
.
NULLABLE_AXIS_WITHOUT_2
)
.
defaultValue
(-
1
)
.
build
()));
declaration
.
setParameters
(
parameters
);
return
declaration
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Repeat.java
View file @
09af254b
...
...
@@ -37,45 +37,44 @@ public class Repeat extends PredefinedLayerDeclaration {
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
int
repeats
=
layer
.
getIntValue
(
AllPredefinedLayers
.
REPEATS_NAME
).
get
();
int
axis
;
boolean
axisIsNone
=
layer
.
getStringValue
(
AllPredefinedLayers
.
AXIS_NAME
).
isPresent
();
int
axis
=
layer
.
getIntValue
(
AllPredefinedLayers
.
AXIS_NAME
).
get
();
int
channels
=
layer
.
getInputTypes
().
get
(
0
).
getChannels
();
int
height
=
layer
.
getInputTypes
().
get
(
0
).
getHeight
();
int
width
=
layer
.
getInputTypes
().
get
(
0
).
getWidth
();
if
(
axisIsNone
){
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
channels
*=
repeats
)
.
height
(
1
)
.
width
(
1
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
}
else
{
axis
=
layer
.
getIntValue
(
AllPredefinedLayers
.
AXIS_NAME
).
get
();
if
(
axis
==
0
){
height
*=
repeats
;
}
else
{
width
*=
repeats
;
}
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
channels
)
.
height
(
height
)
.
width
(
width
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
if
(
axis
==
0
)
{
channels
*=
repeats
;
}
else
if
(
axis
==
1
)
{
height
*=
repeats
;
}
else
if
(
axis
==
2
)
{
width
*=
repeats
;
}
else
{
// when no axis is given, expand dimension and repeat in new, first dimension
width
=
height
;
height
=
channels
;
channels
=
repeats
;
}
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
channels
)
.
height
(
height
)
.
width
(
width
)
.
elementType
(
layer
.
getInputTypes
().
get
(
0
).
getDomain
())
.
build
());
}
@Override
public
void
checkInput
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
errorIfInputSizeIsNotOne
(
inputTypes
,
layer
);
errorIfAxisNotFeasible
(
inputTypes
,
layer
);
int
axis
=
layer
.
getIntValue
(
AllPredefinedLayers
.
AXIS_NAME
).
get
();
if
(
axis
==
-
1
)
{
errorIfInputWidthIsInvalid
(
inputTypes
,
layer
,
1
);
}
}
public
static
Repeat
create
(){
...
...
@@ -85,10 +84,10 @@ public class Repeat extends PredefinedLayerDeclaration {
.
name
(
AllPredefinedLayers
.
REPEATS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
build
(),
// no constraints in order to allow 'None' value
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
AXIS_NAME
)
.
defaultValue
(
1
)
.
constraints
(
Constraints
.
NULLABLE_AXIS
)
.
defaultValue
(-
1
)
.
build
()));
declaration
.
setParameters
(
parameters
);
return
declaration
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Squeeze.java
0 → 100644
View file @
09af254b
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch.predefined
;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
java.util.ArrayList
;
import
java.util.Arrays
;
import
java.util.Collections
;
import
java.util.List
;
public
class
Squeeze
extends
PredefinedLayerDeclaration
{
private
Squeeze
()
{
super
(
AllPredefinedLayers
.
SQUEEZE_NAME
);
}
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
List
<
Integer
>
dimensions
=
layer
.
getInputTypes
().
get
(
0
).
getDimensions
();
int
axis
=
layer
.
getIntValue
(
AllPredefinedLayers
.
AXIS_NAME
).
get
();
if
(
axis
==
-
1
)
{
dimensions
.
remove
(
new
Integer
(
1
));
}
else
{
dimensions
.
remove
(
axis
);
}
while
(
dimensions
.
size
()
<
3
)
{
dimensions
.
add
(
1
);
}
int
channels
=
dimensions
.
get
(
0
);
int
height
=
dimensions
.
get
(
1
);
int
width
=
dimensions
.
get
(
2
);
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
channels
)
.
height
(
height
)
.
width
(
width
)
.
elementType
(
layer
.
getInputTypes
().
get
(
0
).
getDomain
())
.
build
());
}