Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
languages
CNNArchLang
Commits
7712a02f
Commit
7712a02f
authored
Aug 23, 2019
by
Christian Fuß
Browse files
some small fixes to UnrollSymbol
parent
8b12e410
Pipeline
#175296
passed with stages
in 18 minutes and 35 seconds
Changes
10
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/main/grammars/de/monticore/lang/monticar/CNNArch.mc4
View file @
7712a02f
...
...
@@ -62,7 +62,7 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Stream
=
elements
:(
ArchitectureElement
||
"->"
)+;
Unroll
=
"timed"
"<"
timeParameter
:
Architecture
Parameter
">"
Unroll
=
"timed"
"<"
timeParameter
:
Layer
Parameter
">"
Name
"("
arguments
:(
ArchArgument
||
","
)*
")"
"{"
body
:
Stream
"}"
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_cocos/CheckLayer.java
View file @
7712a02f
...
...
@@ -52,63 +52,29 @@ public class CheckLayer implements CNNArchASTLayerCoCo{
nameSet
.
add
(
name
);
}
}
if
(
node
.
getSymbolOpt
().
get
()
instanceof
LayerSymbol
)
{
LayerDeclarationSymbol
layerDeclaration
=
((
LayerSymbol
)
node
.
getSymbolOpt
().
get
()).
getDeclaration
();
if
(
layerDeclaration
==
null
){
ArchitectureSymbol
architecture
=
node
.
getSymbolOpt
().
get
().
getEnclosingScope
().<
ArchitectureSymbol
>
resolve
(
""
,
ArchitectureSymbol
.
KIND
).
get
();
Log
.
error
(
"0"
+
ErrorCodes
.
UNKNOWN_LAYER
+
" Unknown layer. "
+
"Layer with name '"
+
node
.
getName
()
+
"' does not exist. "
+
"Existing layers: "
+
Joiners
.
COMMA
.
join
(
architecture
.
getLayerDeclarations
())
+
"."
,
node
.
get_SourcePositionStart
());
}
else
{
Set
<
String
>
requiredArguments
=
new
HashSet
<>();
for
(
ParameterSymbol
param
:
layerDeclaration
.
getParameters
()){
if
(!
param
.
getDefaultExpression
().
isPresent
()){
requiredArguments
.
add
(
param
.
getName
());
}
}
for
(
ASTArchArgument
argument
:
node
.
getArgumentsList
()){
requiredArguments
.
remove
(
argument
.
getName
());
}
for
(
String
missingArgumentName
:
requiredArguments
){
Log
.
error
(
"0"
+
ErrorCodes
.
MISSING_ARGUMENT
+
" Missing argument. "
+
"The argument '"
+
missingArgumentName
+
"' is required."
,
node
.
get_SourcePositionStart
());
LayerDeclarationSymbol
layerDeclaration
=
((
LayerSymbol
)
node
.
getSymbolOpt
().
get
()).
getDeclaration
();
if
(
layerDeclaration
==
null
){
ArchitectureSymbol
architecture
=
node
.
getSymbolOpt
().
get
().
getEnclosingScope
().<
ArchitectureSymbol
>
resolve
(
""
,
ArchitectureSymbol
.
KIND
).
get
();
Log
.
error
(
"0"
+
ErrorCodes
.
UNKNOWN_LAYER
+
" Unknown layer. "
+
"Layer with name '"
+
node
.
getName
()
+
"' does not exist. "
+
"Existing layers: "
+
Joiners
.
COMMA
.
join
(
architecture
.
getLayerDeclarations
())
+
"."
,
node
.
get_SourcePositionStart
());
}
else
{
Set
<
String
>
requiredArguments
=
new
HashSet
<>();
for
(
ParameterSymbol
param
:
layerDeclaration
.
getParameters
()){
if
(!
param
.
getDefaultExpression
().
isPresent
()){
requiredArguments
.
add
(
param
.
getName
());
}
}
}
else
{
UnrollDeclarationSymbol
unrollDeclaration
=
((
UnrollSymbol
)
node
.
getSymbolOpt
().
get
()).
getDeclaration
();
if
(
unrollDeclaration
==
null
){
ArchitectureSymbol
architecture
=
node
.
getSymbolOpt
().
get
().
getEnclosingScope
().<
ArchitectureSymbol
>
resolve
(
""
,
ArchitectureSymbol
.
KIND
).
get
();
Log
.
error
(
"0"
+
ErrorCodes
.
UNKNOWN_LAYER
+
" Unknown layer. "
+
"Layer with name '"
+
node
.
getName
()
+
"' does not exist. "
+
"Existing layers: "
+
Joiners
.
COMMA
.
join
(
architecture
.
getLayerDeclarations
())
+
"."
,
node
.
get_SourcePositionStart
());
for
(
ASTArchArgument
argument
:
node
.
getArgumentsList
()){
requiredArguments
.
remove
(
argument
.
getName
());
}
else
{
Set
<
String
>
requiredArguments
=
new
HashSet
<>();
for
(
ParameterSymbol
param
:
unrollDeclaration
.
getParameters
()){
if
(!
param
.
getDefaultExpression
().
isPresent
()){
requiredArguments
.
add
(
param
.
getName
());
}
}
for
(
ASTArchArgument
argument
:
node
.
getArgumentsList
()){
requiredArguments
.
remove
(
argument
.
getName
());
}
for
(
String
missingArgumentName
:
requiredArguments
){
Log
.
error
(
"0"
+
ErrorCodes
.
MISSING_ARGUMENT
+
" Missing argument. "
+
"The argument '"
+
missingArgumentName
+
"' is required."
,
node
.
get_SourcePositionStart
());
}
for
(
LayerSymbol
sublayer:
unrollDeclaration
.
getLayers
()){
check
((
ASTLayer
)
sublayer
.
getAstNode
().
get
());
}
for
(
String
missingArgumentName
:
requiredArguments
){
Log
.
error
(
"0"
+
ErrorCodes
.
MISSING_ARGUMENT
+
" Missing argument. "
+
"The argument '"
+
missingArgumentName
+
"' is required."
,
node
.
get_SourcePositionStart
());
}
}
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArgumentSymbol.java
View file @
7712a02f
...
...
@@ -104,6 +104,7 @@ public class ArgumentSymbol extends CommonSymbol {
public
void
set
(){
if
(
getRhs
().
isResolved
()
&&
getRhs
().
isSimpleValue
()){
getParameter
().
setExpression
((
ArchSimpleExpressionSymbol
)
getRhs
());
getUnrollParameter
().
setExpression
((
ArchSimpleExpressionSymbol
)
getRhs
());
}
else
{
throw
new
IllegalStateException
(
"The value of the parameter is set to a sequence or the expression is not resolved. This should never happen."
);
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/CNNArchSymbolTableCreator.java
View file @
7712a02f
...
...
@@ -31,6 +31,7 @@ import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables;
import
de.monticore.symboltable.*
;
import
de.se_rwth.commons.logging.Log
;
import
java.lang.reflect.Array
;
import
java.util.*
;
public
class
CNNArchSymbolTableCreator
extends
de
.
monticore
.
symboltable
.
CommonSymbolTableCreator
...
...
@@ -346,26 +347,15 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
public
void
endVisit
(
ASTUnroll
ast
)
{
UnrollSymbol
layer
=
(
UnrollSymbol
)
ast
.
getSymbolOpt
().
get
();
layer
.
setBody
((
SerialCompositeElementSymbol
)
ast
.
getBody
().
getSymbolOpt
().
get
());
//layer.getDeclaration().setBody(sces);
List
<
ArgumentSymbol
>
arguments
=
new
ArrayList
<>(
6
);
//ast.getArgumentsList().add(ast.getTimeParameter());
for
(
ASTArchArgument
astArgument
:
ast
.
getArgumentsList
()){
Optional
<
ArgumentSymbol
>
optArgument
=
astArgument
.
getSymbolOpt
().
map
(
e
->
(
ArgumentSymbol
)
e
);
optArgument
.
ifPresent
(
arguments:
:
add
);
}
layer
.
setArguments
(
arguments
);
/*List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTStream astStream : ast.getGroupsList()){
elements.add((SerialCompositeElementSymbol) astStream.getSymbolOpt().get());
}
compositeElement.setElements(elements);
*/
removeCurrentScope
();
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/UnrollSymbol.java
View file @
7712a02f
...
...
@@ -21,6 +21,7 @@
package
de.monticore.lang.monticar.cnnarch._symboltable
;
import
de.monticore.lang.monticar.cnnarch._ast.ASTArchitectureParameter
;
import
de.monticore.lang.monticar.cnnarch.helper.ErrorCodes
;
import
de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers
;
import
de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables
;
...
...
@@ -39,7 +40,7 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
private
UnrollDeclarationSymbol
declaration
=
null
;
private
List
<
ArgumentSymbol
>
arguments
;
private
Set
<
ParameterSymbol
>
unresolvabl
eParameter
s
=
null
;
private
ParameterSymbol
tim
eParameter
;
private
UnrollSymbol
resolvedThis
=
null
;
private
SerialCompositeElementSymbol
body
;
...
...
@@ -86,6 +87,14 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
this
.
arguments
=
arguments
;
}
public
ParameterSymbol
getTimeParameter
(){
return
timeParameter
;
}
protected
void
setTimeParameter
(
ParameterSymbol
timeParameter
){
this
.
timeParameter
=
timeParameter
;
}
public
ArchExpressionSymbol
getIfExpression
(){
Optional
<
ArgumentSymbol
>
argument
=
getArgument
(
AllPredefinedVariables
.
CONDITIONAL_ARG_NAME
);
if
(
argument
.
isPresent
()){
...
...
@@ -222,6 +231,50 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
}
}
public
void
setIntValue
(
String
parameterName
,
int
value
)
{
setTValue
(
parameterName
,
value
,
ArchSimpleExpressionSymbol:
:
of
);
}
public
void
setIntTupleValue
(
String
parameterName
,
List
<
Object
>
tupleValues
)
{
setTValue
(
parameterName
,
tupleValues
,
ArchSimpleExpressionSymbol:
:
of
);
}
public
void
setBooleanValue
(
String
parameterName
,
boolean
value
)
{
setTValue
(
parameterName
,
value
,
ArchSimpleExpressionSymbol:
:
of
);
}
public
void
setStringValue
(
String
parameterName
,
String
value
)
{
setTValue
(
parameterName
,
value
,
ArchSimpleExpressionSymbol:
:
of
);
}
public
void
setDoubleValue
(
String
parameterName
,
double
value
)
{
setTValue
(
parameterName
,
value
,
ArchSimpleExpressionSymbol:
:
of
);
}
public
void
setValue
(
String
parameterName
,
Object
value
)
{
ArchSimpleExpressionSymbol
res
=
new
ArchSimpleExpressionSymbol
();
res
.
setValue
(
value
);
setTValue
(
parameterName
,
res
,
Function
.
identity
());
}
public
<
T
>
void
setTValue
(
String
parameterName
,
T
value
,
Function
<
T
,
ArchSimpleExpressionSymbol
>
of
)
{
Optional
<
ParameterSymbol
>
param
=
getDeclaration
().
getParameter
(
parameterName
);
if
(
param
.
isPresent
())
{
Optional
<
ArgumentSymbol
>
arg
=
getArgument
(
parameterName
);
ArchSimpleExpressionSymbol
expression
=
of
.
apply
(
value
);
if
(
arg
.
isPresent
())
{
arg
.
get
().
setRhs
(
expression
);
}
else
{
arg
=
Optional
.
of
(
new
ArgumentSymbol
(
parameterName
));
arg
.
get
().
setRhs
(
expression
);
arguments
.
add
(
arg
.
get
());
}
}
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java
View file @
7712a02f
...
...
@@ -74,6 +74,7 @@ public class AllPredefinedLayers {
public
static
final
String
OUTPUT_DIM_NAME
=
"output_dim"
;
public
static
final
String
BEAMSEARCH_MAX_LENGTH
=
"max_length"
;
public
static
final
String
BEAMSEARCH_WIDTH_NAME
=
"width"
;
public
static
final
String
BEAMSEARCH_T_NAME
=
"t"
;
//possible String values
public
static
final
String
PADDING_VALID
=
"valid"
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/BeamSearchStart.java
View file @
7712a02f
...
...
@@ -104,17 +104,15 @@ public class BeamSearchStart extends PredefinedUnrollDeclaration {
.
name
(
AllPredefinedLayers
.
BEAMSEARCH_MAX_LENGTH
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
BEAMSEARCH_T_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
NON_NEGATIVE
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
BEAMSEARCH_WIDTH_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
build
()));
declaration
.
setParameters
(
parameters
);
declaration
.
setLayers
(
declaration
.
layers
);
for
(
LayerSymbol
layer:
declaration
.
layers
){
for
(
ArgumentSymbol
a:
layer
.
getArguments
())
{
//layer.setIntValue(a.getName(), 10);
}
}
return
declaration
;
}
}
src/test/java/de/monticore/lang/monticar/cnnarch/cocos/AllCoCoTest.java
View file @
7712a02f
...
...
@@ -45,9 +45,9 @@ public class AllCoCoTest extends AbstractCoCoTest {
@Test
public
void
testValidCoCos
(){
checkValid
(
"valid_tests"
,
"RNNencdec"
);
checkValid
(
"architectures"
,
"ResNeXt50"
);
checkValid
(
"architectures"
,
"ResNet152"
);
checkValid
(
"architectures"
,
"Alexnet"
);
checkValid
(
"architectures"
,
"ResNeXt50"
);
checkValid
(
"architectures"
,
"ResNet34"
);
checkValid
(
"architectures"
,
"SequentialAlexnet"
);
checkValid
(
"architectures"
,
"ThreeInputCNN_M14"
);
...
...
src/test/resources/architectures/RNNsearch.cnna
View file @
7712a02f
...
...
@@ -14,7 +14,7 @@ architecture RNNsearch(max_length=50, vocabulary_size=30001, embedding_size=620,
1 -> OneHot(n=vocabulary_size) -> target[0];
encoder.state[1] -> decoder.state;
unroll<t> BeamSearchStart(width=5, max_length=50) {
unroll<t
=0
> BeamSearchStart(width=5, max_length=50) {
(
(
(
...
...
src/test/resources/valid_tests/RNNencdec.cnna
View file @
7712a02f
architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
def input Q(0:1)^{vocabulary_size} source
def output Q(0:1)^{vocabulary_size} target
def output Q(0:1)^{vocabulary_size} target
[3]
timed<t> BeamSearchStart(max_length=max_length) {
source ->
FullyConnected(units=17) ->
Softmax() ->
target
};
source -> Softmax() -> target[0];
timed <t=2> BeamSearchStart(max_length=3){
target[t-1] ->
FullyConnected(units=30000) ->
Softmax() ->
target[t]
};
}
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment