Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
languages
CNNArchLang
Commits
0eb98f6f
Commit
0eb98f6f
authored
Aug 18, 2019
by
Sebastian Nickels
Browse files
Implemented bidirectional RNNs
parent
afb3db82
Changes
5
Show whitespace changes
Inline
Side-by-side
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java
View file @
0eb98f6f
...
...
@@ -71,6 +71,7 @@ public class AllPredefinedLayers {
public
static
final
String
LAYERS_NAME
=
"layers"
;
public
static
final
String
INPUT_DIM_NAME
=
"input_dim"
;
public
static
final
String
OUTPUT_DIM_NAME
=
"output_dim"
;
public
static
final
String
BIDIRECTIONAL_NAME
=
"bidirectional"
;
public
static
final
String
BEAMSEARCH_MAX_LENGTH
=
"max_length"
;
public
static
final
String
BEAMSEARCH_WIDTH_NAME
=
"width"
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/BaseRNN.java
View file @
0eb98f6f
...
...
@@ -22,6 +22,8 @@ package de.monticore.lang.monticar.cnnarch.predefined;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
java.util.ArrayList
;
import
java.util.Arrays
;
import
java.util.Collections
;
import
java.util.List
;
...
...
@@ -38,13 +40,14 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
boolean
bidirectional
=
layer
.
getBooleanValue
(
AllPredefinedLayers
.
BIDIRECTIONAL_NAME
).
get
();
int
units
=
layer
.
getIntValue
(
AllPredefinedLayers
.
UNITS_NAME
).
get
();
if
(
member
==
VariableSymbol
.
Member
.
STATE
)
{
int
layers
=
layer
.
getIntValue
(
AllPredefinedLayers
.
LAYERS_NAME
).
get
();
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
layers
)
.
channels
(
bidirectional
?
2
*
layers
:
layers
)
.
height
(
units
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
...
...
@@ -52,7 +55,7 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
else
{
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
layer
.
getInputTypes
().
get
(
0
).
getChannels
())
.
height
(
units
)
.
height
(
bidirectional
?
2
*
units
:
units
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
}
...
...
@@ -60,12 +63,13 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
@Override
public
void
checkInput
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
boolean
bidirectional
=
layer
.
getBooleanValue
(
AllPredefinedLayers
.
BIDIRECTIONAL_NAME
).
get
();
int
units
=
layer
.
getIntValue
(
AllPredefinedLayers
.
UNITS_NAME
).
get
();
int
layers
=
layer
.
getIntValue
(
AllPredefinedLayers
.
LAYERS_NAME
).
get
();
if
(
member
==
VariableSymbol
.
Member
.
STATE
)
{
errorIfInputSizeIsNotOne
(
inputTypes
,
layer
);
errorIfInputChannelSizeIsInvalid
(
inputTypes
,
layer
,
layers
);
errorIfInputChannelSizeIsInvalid
(
inputTypes
,
layer
,
bidirectional
?
2
*
layers
:
layers
);
errorIfInputHeightIsInvalid
(
inputTypes
,
layer
,
units
);
errorIfInputWidthIsInvalid
(
inputTypes
,
layer
,
1
);
}
...
...
@@ -92,4 +96,22 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
public
boolean
canBeOutput
(
VariableSymbol
.
Member
member
)
{
return
member
==
VariableSymbol
.
Member
.
NONE
||
member
==
VariableSymbol
.
Member
.
STATE
;
}
protected
static
List
<
ParameterSymbol
>
createParameters
()
{
return
new
ArrayList
<>(
Arrays
.
asList
(
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
UNITS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
LAYERS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
defaultValue
(
1
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
BIDIRECTIONAL_NAME
)
.
constraints
(
Constraints
.
BOOLEAN
)
.
defaultValue
(
false
)
.
build
()));
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/GRU.java
View file @
0eb98f6f
...
...
@@ -20,13 +20,6 @@
*/
package
de.monticore.lang.monticar.cnnarch.predefined
;
import
de.monticore.lang.monticar.cnnarch._symboltable.Constraints
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ParameterSymbol
;
import
java.util.ArrayList
;
import
java.util.Arrays
;
import
java.util.List
;
public
class
GRU
extends
BaseRNN
{
private
GRU
()
{
...
...
@@ -35,17 +28,7 @@ public class GRU extends BaseRNN {
public
static
GRU
create
()
{
GRU
declaration
=
new
GRU
();
List
<
ParameterSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
UNITS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
LAYERS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
defaultValue
(
1
)
.
build
()));
declaration
.
setParameters
(
parameters
);
declaration
.
setParameters
(
createParameters
());
return
declaration
;
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/LSTM.java
View file @
0eb98f6f
...
...
@@ -20,13 +20,6 @@
*/
package
de.monticore.lang.monticar.cnnarch.predefined
;
import
de.monticore.lang.monticar.cnnarch._symboltable.Constraints
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ParameterSymbol
;
import
java.util.ArrayList
;
import
java.util.Arrays
;
import
java.util.List
;
public
class
LSTM
extends
BaseRNN
{
private
LSTM
()
{
...
...
@@ -35,17 +28,7 @@ public class LSTM extends BaseRNN {
public
static
LSTM
create
()
{
LSTM
declaration
=
new
LSTM
();
List
<
ParameterSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
UNITS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
LAYERS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
defaultValue
(
1
)
.
build
()));
declaration
.
setParameters
(
parameters
);
declaration
.
setParameters
(
createParameters
());
return
declaration
;
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/RNN.java
View file @
0eb98f6f
...
...
@@ -20,10 +20,6 @@
*/
package
de.monticore.lang.monticar.cnnarch.predefined
;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
java.util.*
;
public
class
RNN
extends
BaseRNN
{
private
RNN
()
{
...
...
@@ -32,17 +28,7 @@ public class RNN extends BaseRNN {
public
static
RNN
create
()
{
RNN
declaration
=
new
RNN
();
List
<
ParameterSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
UNITS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
LAYERS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
defaultValue
(
1
)
.
build
()));
declaration
.
setParameters
(
parameters
);
declaration
.
setParameters
(
createParameters
());
return
declaration
;
}
}
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment