Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
languages
CNNArchLang
Commits
a5b546e2
Commit
a5b546e2
authored
Jul 03, 2019
by
Sebastian Nickels
Browse files
Changed OneHot layer and added support for constants
parent
6e5d8673
Pipeline
#155895
passed with stages
in 18 minutes and 44 seconds
Changes
8
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/main/grammars/de/monticore/lang/monticar/CNNArch.mc4
View file @
a5b546e2
...
...
@@ -69,6 +69,8 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
IOElement
implements
ArchitectureElement
=
Name
(
"["
index
:
ArchSimpleExpression
"]"
)?;
Constant
implements
ArchitectureElement
=
ArchSimpleExpression
;
Layer
implements
ArchitectureElement
=
Name
"("
arguments
:(
ArchArgument
||
","
)*
")"
;
Unroll
implements
ArchitectureElement
=
"unroll"
"<"
timeParameter
:
LayerParameter
">"
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/CNNArchSymbolTableCreator.java
View file @
a5b546e2
...
...
@@ -20,7 +20,6 @@
*/
package
de.monticore.lang.monticar.cnnarch._symboltable
;
import
de.monticore.expressionsbasis._ast.ASTExpression
;
import
de.monticore.lang.math._symboltable.MathSymbolTableCreator
;
import
de.monticore.lang.math._symboltable.expression.*
;
...
...
@@ -427,6 +426,17 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
addToScopeAndLinkWithNode
(
argument
,
node
);
}
public
void
visit
(
ASTConstant
node
)
{
ConstantSymbol
constant
=
new
ConstantSymbol
();
addToScopeAndLinkWithNode
(
constant
,
node
);
}
public
void
endVisit
(
ASTConstant
node
)
{
ConstantSymbol
constant
=
(
ConstantSymbol
)
node
.
getSymbolOpt
().
get
();
constant
.
setExpression
((
ArchSimpleExpressionSymbol
)
node
.
getArchSimpleExpression
().
getSymbolOpt
().
get
());
removeCurrentScope
();
}
public
void
visit
(
ASTIOElement
node
)
{
IOSymbol
ioElement
=
new
IOSymbol
(
node
.
getName
());
addToScopeAndLinkWithNode
(
ioElement
,
node
);
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ConstantSymbol.java
0 → 100644
View file @
a5b546e2
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch._symboltable
;
import
de.monticore.lang.math._symboltable.expression.MathExpressionSymbol
;
import
de.monticore.lang.monticar.cnnarch.helper.ErrorCodes
;
import
de.monticore.lang.monticar.cnnarch.helper.Utils
;
import
de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers
;
import
de.monticore.lang.monticar.ranges._ast.ASTRange
;
import
de.monticore.lang.monticar.types2._ast.ASTElementType
;
import
de.monticore.symboltable.Scope
;
import
de.monticore.symboltable.Symbol
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.*
;
public
class
ConstantSymbol
extends
ArchitectureElementSymbol
{
private
ArchSimpleExpressionSymbol
expression
=
null
;
protected
ConstantSymbol
()
{
super
(
"const"
);
}
public
ArchSimpleExpressionSymbol
getExpression
()
{
return
expression
;
}
protected
void
setExpression
(
ArchSimpleExpressionSymbol
expression
)
{
this
.
expression
=
expression
;
}
@Override
public
boolean
isResolvable
()
{
return
super
.
isResolvable
();
}
@Override
public
boolean
isAtomic
()
{
return
getResolvedThis
().
isPresent
()
&&
getResolvedThis
().
get
()
==
this
;
}
@Override
public
List
<
ArchitectureElementSymbol
>
getFirstAtomicElements
()
{
if
(
getResolvedThis
().
isPresent
()
&&
getResolvedThis
().
get
()
!=
this
)
{
return
getResolvedThis
().
get
().
getFirstAtomicElements
();
}
else
{
return
Collections
.
singletonList
(
this
);
}
}
@Override
public
List
<
ArchitectureElementSymbol
>
getLastAtomicElements
()
{
if
(
getResolvedThis
().
isPresent
()
&&
getResolvedThis
().
get
()
!=
this
)
{
return
getResolvedThis
().
get
().
getLastAtomicElements
();
}
else
{
return
Collections
.
singletonList
(
this
);
}
}
@Override
public
Set
<
VariableSymbol
>
resolve
()
throws
ArchResolveException
{
if
(!
isResolved
())
{
if
(
isResolvable
())
{
resolveExpressions
();
setResolvedThis
(
this
);
}
}
return
getUnresolvableVariables
();
}
@Override
protected
void
computeUnresolvableVariables
(
Set
<
VariableSymbol
>
unresolvableVariables
,
Set
<
VariableSymbol
>
allVariables
)
{
getExpression
().
checkIfResolvable
(
allVariables
);
unresolvableVariables
.
addAll
(
getExpression
().
getUnresolvableVariables
());
}
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
()
{
List
<
ArchTypeSymbol
>
outputShapes
;
if
(
isAtomic
())
{
ArchTypeSymbol
outputShape
=
new
ArchTypeSymbol
();
// Since symbol is resolved at this point, it is safe to assume that the expression is an int
int
value
=
getExpression
().
getIntValue
().
get
();
ASTRange
range
=
new
ASTRange
();
range
.
setStartValue
(
String
.
valueOf
(
value
));
range
.
setEndValue
(
String
.
valueOf
(
value
));
ASTElementType
domain
=
new
ASTElementType
(
"Z"
,
Optional
.
of
(
range
));
outputShape
.
setDomain
(
domain
);
outputShapes
=
Collections
.
singletonList
(
outputShape
);
}
else
{
if
(!
getResolvedThis
().
isPresent
()){
throw
new
IllegalStateException
(
"The architecture resolve() method was never called"
);
}
outputShapes
=
getResolvedThis
().
get
().
computeOutputTypes
();
}
return
outputShapes
;
}
@Override
public
void
checkInput
()
{
if
(
isAtomic
())
{
if
(!
getInputTypes
().
isEmpty
())
{
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_ELEMENT_INPUT_SHAPE
+
" Invalid number of input streams. "
,
getSourcePosition
());
}
}
else
{
if
(!
getResolvedThis
().
isPresent
())
{
throw
new
IllegalStateException
(
"The architecture resolve() method was never called"
);
}
getResolvedThis
().
get
().
checkInput
();
}
}
@Override
public
Optional
<
Integer
>
getParallelLength
()
{
return
Optional
.
of
(
1
);
}
@Override
public
Optional
<
List
<
Integer
>>
getSerialLengths
()
{
return
Optional
.
of
(
Collections
.
nCopies
(
getParallelLength
().
get
(),
1
));
}
@Override
protected
void
putInScope
(
Scope
scope
)
{
Collection
<
Symbol
>
symbolsInScope
=
scope
.
getLocalSymbols
().
get
(
getName
());
if
(
symbolsInScope
==
null
||
!
symbolsInScope
.
contains
(
this
))
{
scope
.
getAsMutableScope
().
add
(
this
);
getExpression
().
putInScope
(
getSpannedScope
());
}
}
@Override
protected
void
resolveExpressions
()
throws
ArchResolveException
{
getExpression
().
resolveOrError
();
if
(!
Constraints
.
INTEGER
.
check
(
getExpression
(),
getSourcePosition
(),
getName
()))
{
throw
new
ArchResolveException
();
}
}
@Override
protected
ArchitectureElementSymbol
preResolveDeepCopy
()
{
ConstantSymbol
copy
=
new
ConstantSymbol
();
if
(
getAstNode
().
isPresent
())
{
copy
.
setAstNode
(
getAstNode
().
get
());
}
copy
.
setExpression
(
getExpression
());
return
copy
;
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java
View file @
a5b546e2
...
...
@@ -91,6 +91,36 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
}
protected
void
errorIfInputChannelSizeIsInvalid
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
int
channels
)
{
for
(
ArchTypeSymbol
inputType
:
inputTypes
)
{
if
(
inputType
.
getChannels
()
!=
channels
)
{
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_ELEMENT_INPUT_SHAPE
+
" Invalid layer input. Input channel size is "
+
inputType
.
getChannels
()
+
" but needs to be "
+
channels
+
"."
,
layer
.
getSourcePosition
());
}
}
}
protected
void
errorIfInputHeightIsInvalid
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
int
height
)
{
for
(
ArchTypeSymbol
inputType
:
inputTypes
)
{
if
(
inputType
.
getHeight
()
!=
height
)
{
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_ELEMENT_INPUT_SHAPE
+
" Invalid layer input. Input height is "
+
inputType
.
getHeight
()
+
" but needs to be "
+
height
+
"."
,
layer
.
getSourcePosition
());
}
}
}
protected
void
errorIfInputWidthIsInvalid
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
int
width
)
{
for
(
ArchTypeSymbol
inputType
:
inputTypes
)
{
if
(
inputType
.
getWidth
()
!=
width
)
{
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_ELEMENT_INPUT_SHAPE
+
" Invalid layer input. Input width is "
+
inputType
.
getWidth
()
+
" but needs to be "
+
width
+
"."
,
layer
.
getSourcePosition
());
}
}
}
//check input for convolution and pooling
protected
static
void
errorIfInputSmallerThanKernel
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
){
if
(!
inputTypes
.
isEmpty
())
{
...
...
@@ -116,22 +146,6 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
}
//check input for onehot layer
protected
static
void
errorIfInputSizeUnequalToOnehotSize
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
){
if
(!
inputTypes
.
isEmpty
()
&&
layer
.
getIntValue
(
AllPredefinedLayers
.
ONE_HOT_SIZE_NAME
).
get
()
!=
0
)
{
int
inputChannels
=
inputTypes
.
get
(
0
).
getChannels
();
int
onehotSize
=
layer
.
getIntValue
(
AllPredefinedLayers
.
ONE_HOT_SIZE_NAME
).
get
();
if
(
onehotSize
!=
inputChannels
){
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_ELEMENT_INPUT_SHAPE
+
"The size of the onehot vector is not equal to the output size of the previous layer."
+
"This is usually not intended."
,
layer
.
getSourcePosition
());
}
}
}
//output type function for convolution and pooling
protected
static
List
<
ArchTypeSymbol
>
computeConvAndPoolOutputShape
(
ArchTypeSymbol
inputType
,
LayerSymbol
method
,
int
channels
)
{
String
borderModeSetting
=
method
.
getStringValue
(
AllPredefinedLayers
.
PADDING_NAME
).
get
();
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java
View file @
a5b546e2
...
...
@@ -63,7 +63,7 @@ public class AllPredefinedLayers {
public
static
final
String
BETA_NAME
=
"beta"
;
public
static
final
String
PADDING_NAME
=
"padding"
;
public
static
final
String
POOL_TYPE_NAME
=
"pool_type"
;
public
static
final
String
ONE_HOT_
SIZE_NAME
=
"size"
;
public
static
final
String
SIZE_NAME
=
"size"
;
public
static
final
String
BEAMSEARCH_MAX_LENGTH
=
"max_length"
;
public
static
final
String
BEAMSEARCH_WIDTH_NAME
=
"width"
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/OneHot.java
View file @
a5b546e2
...
...
@@ -21,11 +21,13 @@
package
de.monticore.lang.monticar.cnnarch.predefined
;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
de.monticore.lang.monticar.cnnarch.helper.ErrorCodes
;
import
de.monticore.lang.monticar.ranges._ast.ASTRange
;
import
de.monticore.lang.monticar.ranges._ast.ASTRangeStepResolution
;
import
de.monticore.lang.monticar.types2._ast.ASTElementType
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.ArrayList
;
import
java.util.Arrays
;
import
java.util.Collections
;
import
java.util.List
;
import
java.util.*
;
public
class
OneHot
extends
PredefinedLayerDeclaration
{
...
...
@@ -35,14 +37,12 @@ public class OneHot extends PredefinedLayerDeclaration {
super
(
AllPredefinedLayers
.
ONE_HOT_NAME
);
}
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
)
{
channels
=
layer
.
getIntValue
(
AllPredefinedLayers
.
ONE_HOT_SIZE_NAME
).
get
();
channels
=
layer
.
getIntValue
(
AllPredefinedLayers
.
SIZE_NAME
).
get
();
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
layer
.
getIntValue
(
AllPredefinedLayers
.
ONE_HOT_
SIZE_NAME
).
get
())
.
channels
(
layer
.
getIntValue
(
AllPredefinedLayers
.
SIZE_NAME
).
get
())
.
height
(
1
)
.
width
(
1
)
.
elementType
(
"0"
,
"1"
)
...
...
@@ -52,14 +52,58 @@ public class OneHot extends PredefinedLayerDeclaration {
@Override
public
void
checkInput
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
)
{
errorIfInputSizeIsNotOne
(
inputTypes
,
layer
);
errorIfInputSizeUnequalToOnehotSize
(
inputTypes
,
layer
);
errorIfInputChannelSizeIsInvalid
(
inputTypes
,
layer
,
1
);
errorIfInputHeightIsInvalid
(
inputTypes
,
layer
,
1
);
errorIfInputWidthIsInvalid
(
inputTypes
,
layer
,
1
);
// Check range of input
ASTElementType
domain
=
inputTypes
.
get
(
0
).
getDomain
();
if
(!
domain
.
isNaturalNumber
()
&&
!
domain
.
isWholeNumber
())
{
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_ELEMENT_INPUT_DOMAIN
+
" Invalid layer input domain: Input needs to be natural or whole. "
,
layer
.
getSourcePosition
());
}
else
{
if
(!
domain
.
getRangeOpt
().
isPresent
())
{
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_ELEMENT_INPUT_DOMAIN
+
" Invalid layer input domain: Range is missing. "
,
layer
.
getSourcePosition
());
}
else
{
ASTRange
range
=
domain
.
getRangeOpt
().
get
();
if
(!
range
.
getMin
().
getNumber
().
isPresent
()
||
!
range
.
getMax
().
getNumber
().
isPresent
())
{
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_ELEMENT_INPUT_DOMAIN
+
" Invalid layer input domain: Minimum or maximum is missing. "
,
layer
.
getSourcePosition
());
}
else
{
double
min
=
range
.
getMin
().
getNumber
().
get
();
double
max
=
range
.
getMax
().
getNumber
().
get
();
// Check if minimum >= 0
if
(
min
<
0
)
{
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_ELEMENT_INPUT_DOMAIN
+
" Invalid layer input domain: Minimum needs to be bigger than 0. "
,
layer
.
getSourcePosition
());
}
int
size
=
layer
.
getIntValue
(
AllPredefinedLayers
.
SIZE_NAME
).
get
();
// Check if maximum < size
if
(
max
>=
size
)
{
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_ELEMENT_INPUT_DOMAIN
+
" Invalid layer input domain: "
+
"Maximum needs to be smaller than size "
+
size
+
". "
,
layer
.
getSourcePosition
());
}
}
}
}
}
public
static
OneHot
create
(){
OneHot
declaration
=
new
OneHot
();
List
<
VariableSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
new
VariableSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
ONE_HOT_
SIZE_NAME
)
.
name
(
AllPredefinedLayers
.
SIZE_NAME
)
.
constraints
(
Constraints
.
POSITIVE
,
Constraints
.
INTEGER
)
.
defaultValue
(
channels
)
.
build
()));
...
...
src/test/resources/architectures/Alexnet.cnna
View file @
a5b546e2
...
...
@@ -55,6 +55,5 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
Concatenate() ->
FullyConnected(units=10) ->
Softmax() ->
OneHot() ->
predictions;
}
\ No newline at end of file
src/test/resources/valid_tests/Alexnet_alt_OneHotOutput.cnna
View file @
a5b546e2
...
...
@@ -49,6 +49,5 @@ architecture Alexnet_alt_OneHotOutput(img_height=224, img_width=224, img_channel
Dropout() ->
FullyConnected(units=classes) ->
Softmax() ->
OneHot(size=classes) ->
predictions;
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment