Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
languages
CNNArchLang
Commits
f4aa25ae
Commit
f4aa25ae
authored
Aug 16, 2019
by
Sebastian Nickels
Browse files
Added LSTM and GRU layers
parent
f8e3572d
Pipeline
#172312
failed with stages
Changes
5
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java
View file @
f4aa25ae
...
...
@@ -47,6 +47,8 @@ public class AllPredefinedLayers {
public
static
final
String
ONE_HOT_NAME
=
"OneHot"
;
public
static
final
String
BEAMSEARCH_NAME
=
"BeamSearchStart"
;
public
static
final
String
RNN_NAME
=
"RNN"
;
public
static
final
String
LSTM_NAME
=
"LSTM"
;
public
static
final
String
GRU_NAME
=
"GRU"
;
//predefined argument names
public
static
final
String
KERNEL_NAME
=
"kernel"
;
...
...
@@ -97,7 +99,9 @@ public class AllPredefinedLayers {
Add
.
create
(),
Concatenate
.
create
(),
OneHot
.
create
(),
RNN
.
create
());
RNN
.
create
(),
LSTM
.
create
(),
GRU
.
create
());
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/BaseRNN.java
0 → 100644
View file @
f4aa25ae
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch.predefined
;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
java.util.Collections
;
import
java.util.List
;
abstract
public
class
BaseRNN
extends
PredefinedLayerDeclaration
{
public
BaseRNN
(
String
name
)
{
super
(
name
);
}
@Override
public
boolean
isTrainable
(
VariableSymbol
.
Member
member
)
{
return
member
==
VariableSymbol
.
Member
.
NONE
;
}
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
int
units
=
layer
.
getIntValue
(
AllPredefinedLayers
.
UNITS_NAME
).
get
();
if
(
member
==
VariableSymbol
.
Member
.
STATE
)
{
int
layers
=
layer
.
getIntValue
(
AllPredefinedLayers
.
LAYERS_NAME
).
get
();
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
layers
)
.
height
(
units
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
}
else
{
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
layer
.
getInputTypes
().
get
(
0
).
getChannels
())
.
height
(
units
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
}
}
@Override
public
void
checkInput
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
int
units
=
layer
.
getIntValue
(
AllPredefinedLayers
.
UNITS_NAME
).
get
();
int
layers
=
layer
.
getIntValue
(
AllPredefinedLayers
.
LAYERS_NAME
).
get
();
if
(
member
==
VariableSymbol
.
Member
.
STATE
)
{
errorIfInputSizeIsNotOne
(
inputTypes
,
layer
);
errorIfInputChannelSizeIsInvalid
(
inputTypes
,
layer
,
layers
);
errorIfInputHeightIsInvalid
(
inputTypes
,
layer
,
units
);
errorIfInputWidthIsInvalid
(
inputTypes
,
layer
,
1
);
}
else
{
errorIfInputSizeIsNotOne
(
inputTypes
,
layer
);
errorIfInputChannelSizeIsInvalid
(
inputTypes
,
layer
,
layer
.
getInputTypes
().
get
(
0
).
getChannels
());
errorIfInputWidthIsInvalid
(
inputTypes
,
layer
,
1
);
}
}
@Override
public
boolean
isValidMember
(
VariableSymbol
.
Member
member
)
{
return
member
==
VariableSymbol
.
Member
.
NONE
||
member
==
VariableSymbol
.
Member
.
OUTPUT
||
member
==
VariableSymbol
.
Member
.
STATE
;
}
@Override
public
boolean
canBeInput
(
VariableSymbol
.
Member
member
)
{
return
member
==
VariableSymbol
.
Member
.
OUTPUT
||
member
==
VariableSymbol
.
Member
.
STATE
;
}
@Override
public
boolean
canBeOutput
(
VariableSymbol
.
Member
member
)
{
return
member
==
VariableSymbol
.
Member
.
NONE
||
member
==
VariableSymbol
.
Member
.
STATE
;
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/GRU.java
0 → 100644
View file @
f4aa25ae
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch.predefined
;
import
de.monticore.lang.monticar.cnnarch._symboltable.Constraints
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ParameterSymbol
;
import
java.util.ArrayList
;
import
java.util.Arrays
;
import
java.util.List
;
public
class
GRU
extends
BaseRNN
{
private
GRU
()
{
super
(
AllPredefinedLayers
.
GRU_NAME
);
}
public
static
GRU
create
()
{
GRU
declaration
=
new
GRU
();
List
<
ParameterSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
UNITS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
LAYERS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
defaultValue
(
1
)
.
build
()));
declaration
.
setParameters
(
parameters
);
return
declaration
;
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/LSTM.java
0 → 100644
View file @
f4aa25ae
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch.predefined
;
import
de.monticore.lang.monticar.cnnarch._symboltable.Constraints
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ParameterSymbol
;
import
java.util.ArrayList
;
import
java.util.Arrays
;
import
java.util.List
;
public
class
LSTM
extends
BaseRNN
{
private
LSTM
()
{
super
(
AllPredefinedLayers
.
LSTM_NAME
);
}
public
static
LSTM
create
()
{
LSTM
declaration
=
new
LSTM
();
List
<
ParameterSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
UNITS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
LAYERS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
defaultValue
(
1
)
.
build
()));
declaration
.
setParameters
(
parameters
);
return
declaration
;
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/RNN.java
View file @
f4aa25ae
...
...
@@ -24,74 +24,12 @@ import de.monticore.lang.monticar.cnnarch._symboltable.*;
import
java.util.*
;
public
class
RNN
extends
PredefinedLayerDeclaration
{
public
class
RNN
extends
BaseRNN
{
private
RNN
()
{
super
(
AllPredefinedLayers
.
RNN_NAME
);
}
@Override
public
boolean
isTrainable
(
VariableSymbol
.
Member
member
)
{
return
member
==
VariableSymbol
.
Member
.
NONE
;
}
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
int
units
=
layer
.
getIntValue
(
AllPredefinedLayers
.
UNITS_NAME
).
get
();
if
(
member
==
VariableSymbol
.
Member
.
STATE
)
{
int
layers
=
layer
.
getIntValue
(
AllPredefinedLayers
.
LAYERS_NAME
).
get
();
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
layers
)
.
height
(
units
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
}
else
{
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
layer
.
getInputTypes
().
get
(
0
).
getChannels
())
.
height
(
units
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
}
}
@Override
public
void
checkInput
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
int
units
=
layer
.
getIntValue
(
AllPredefinedLayers
.
UNITS_NAME
).
get
();
int
layers
=
layer
.
getIntValue
(
AllPredefinedLayers
.
LAYERS_NAME
).
get
();
if
(
member
==
VariableSymbol
.
Member
.
STATE
)
{
errorIfInputSizeIsNotOne
(
inputTypes
,
layer
);
errorIfInputChannelSizeIsInvalid
(
inputTypes
,
layer
,
layers
);
errorIfInputHeightIsInvalid
(
inputTypes
,
layer
,
units
);
errorIfInputWidthIsInvalid
(
inputTypes
,
layer
,
1
);
}
else
{
errorIfInputSizeIsNotOne
(
inputTypes
,
layer
);
errorIfInputChannelSizeIsInvalid
(
inputTypes
,
layer
,
layer
.
getInputTypes
().
get
(
0
).
getChannels
());
errorIfInputWidthIsInvalid
(
inputTypes
,
layer
,
1
);
}
}
@Override
public
boolean
isValidMember
(
VariableSymbol
.
Member
member
)
{
return
member
==
VariableSymbol
.
Member
.
NONE
||
member
==
VariableSymbol
.
Member
.
OUTPUT
||
member
==
VariableSymbol
.
Member
.
STATE
;
}
@Override
public
boolean
canBeInput
(
VariableSymbol
.
Member
member
)
{
return
member
==
VariableSymbol
.
Member
.
OUTPUT
||
member
==
VariableSymbol
.
Member
.
STATE
;
}
@Override
public
boolean
canBeOutput
(
VariableSymbol
.
Member
member
)
{
return
member
==
VariableSymbol
.
Member
.
NONE
||
member
==
VariableSymbol
.
Member
.
STATE
;
}
public
static
RNN
create
()
{
RNN
declaration
=
new
RNN
();
List
<
ParameterSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment