Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
C
CNNArchLang
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
1
Issues
1
List
Boards
Labels
Service Desk
Milestones
Iterations
Merge Requests
0
Merge Requests
0
Requirements
Requirements
List
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Test Cases
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issue
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
monticore
EmbeddedMontiArc
languages
CNNArchLang
Commits
9dd5761f
Commit
9dd5761f
authored
Apr 05, 2020
by
Julian Johannes Steinsberger-Dührßen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added generation of replay subnets for ReplayMemory.
parent
de6687e2
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
214 additions
and
19 deletions
+214
-19
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureElementSymbol.java
...ticar/cnnarch/_symboltable/ArchitectureElementSymbol.java
+1
-0
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java
...ang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java
+22
-1
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/CompositeElementSymbol.java
...monticar/cnnarch/_symboltable/CompositeElementSymbol.java
+3
-2
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/Constraints.java
...icore/lang/monticar/cnnarch/_symboltable/Constraints.java
+32
-1
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/NetworkInstructionSymbol.java
...nticar/cnnarch/_symboltable/NetworkInstructionSymbol.java
+3
-1
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/SerialCompositeElementSymbol.java
...ar/cnnarch/_symboltable/SerialCompositeElementSymbol.java
+28
-0
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java
...lang/monticar/cnnarch/predefined/AllPredefinedLayers.java
+14
-4
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Memory.java
...de/monticore/lang/monticar/cnnarch/predefined/Memory.java
+34
-10
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/ReplayMemory.java
...ticore/lang/monticar/cnnarch/predefined/ReplayMemory.java
+77
-0
No files found.
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureElementSymbol.java
View file @
9dd5761f
...
...
@@ -157,6 +157,7 @@ public abstract class ArchitectureElementSymbol extends ResolvableSymbol {
else
{
return
Optional
.
empty
();
}
}
/**
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java
View file @
9dd5761f
...
...
@@ -214,5 +214,26 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
return
copy
;
}
public
void
processForReplayMemory
(){
for
(
NetworkInstructionSymbol
networkInstruction
:
networkInstructions
){
List
<
ArchitectureElementSymbol
>
elements
=
networkInstruction
.
getBody
().
getElements
();
List
<
ArchitectureElementSymbol
>
elementsNew
=
new
ArrayList
<>();
List
<
List
<
ArchitectureElementSymbol
>>
replaySubNetworks
=
new
ArrayList
<>(
new
ArrayList
<>());
List
<
ArchitectureElementSymbol
>
currentReplaySubNetworkElements
=
new
ArrayList
<>();
for
(
ArchitectureElementSymbol
element
:
elements
){
if
(
element
.
getName
().
equals
(
"ReplayMemory"
))
{
if
(!
currentReplaySubNetworkElements
.
isEmpty
()){
replaySubNetworks
.
add
(
currentReplaySubNetworkElements
);
}
currentReplaySubNetworkElements
=
new
ArrayList
<>();
}
currentReplaySubNetworkElements
.
add
(
element
);
}
if
(!
currentReplaySubNetworkElements
.
isEmpty
()
&&
!
replaySubNetworks
.
isEmpty
()){
replaySubNetworks
.
add
(
currentReplaySubNetworkElements
);
}
networkInstruction
.
getBody
().
setReplaySubNetworks
(
replaySubNetworks
);
}
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/CompositeElementSymbol.java
View file @
9dd5761f
...
...
@@ -17,17 +17,18 @@ public abstract class CompositeElementSymbol extends ArchitectureElementSymbol {
protected
List
<
ArchitectureElementSymbol
>
elements
=
new
ArrayList
<>();
public
CompositeElementSymbol
()
{
super
(
""
);
setResolvedThis
(
this
);
}
abstract
protected
void
setElements
(
List
<
ArchitectureElementSymbol
>
elements
);
public
List
<
ArchitectureElementSymbol
>
getElements
()
{
return
elements
;
}
abstract
protected
void
setElements
(
List
<
ArchitectureElementSymbol
>
elements
);
@Override
public
boolean
isAtomic
()
{
return
getElements
().
isEmpty
();
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/Constraints.java
View file @
9dd5761f
...
...
@@ -68,6 +68,16 @@ public enum Constraints {
return
"a tuple of integers"
;
}
},
INTEGER_OR_INTEGER_TUPLE
{
@Override
public
boolean
isValid
(
ArchSimpleExpressionSymbol
exp
)
{
return
exp
.
isInt
().
get
()
||
exp
.
isIntTuple
().
get
();
}
@Override
public
String
msgString
()
{
return
"an integer or tuple of integers"
;
}
},
POSITIVE
{
@Override
public
boolean
isValid
(
ArchSimpleExpressionSymbol
exp
)
{
...
...
@@ -90,6 +100,28 @@ public enum Constraints {
return
"a positive number"
;
}
},
POSITIVE_OR_MINUS_ONE
{
@Override
public
boolean
isValid
(
ArchSimpleExpressionSymbol
exp
)
{
if
(
exp
.
getDoubleValue
().
isPresent
()){
return
exp
.
getDoubleValue
().
get
()
>
0
||
exp
.
getDoubleValue
().
get
()
==
-
1
;
}
else
if
(
exp
.
getDoubleTupleValues
().
isPresent
()){
boolean
isPositive
=
true
;
for
(
double
value
:
exp
.
getDoubleTupleValues
().
get
()){
if
(
value
<
-
1
||
value
==
0
){
isPositive
=
false
;
}
}
return
isPositive
;
}
return
false
;
}
@Override
public
String
msgString
()
{
return
"a positive number"
;
}
},
NON_NEGATIVE
{
@Override
public
boolean
isValid
(
ArchSimpleExpressionSymbol
exp
)
{
...
...
@@ -207,7 +239,6 @@ public enum Constraints {
}
return
false
;
}
@Override
protected
String
msgString
()
{
return
AllPredefinedLayers
.
MEMORY_ACTIVATION_LINEAR
+
" or "
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/NetworkInstructionSymbol.java
View file @
9dd5761f
...
...
@@ -9,6 +9,8 @@ package de.monticore.lang.monticar.cnnarch._symboltable;
import
de.monticore.symboltable.SymbolKind
;
import
java.util.*
;
public
abstract
class
NetworkInstructionSymbol
extends
ResolvableSymbol
{
private
SerialCompositeElementSymbol
body
;
...
...
@@ -16,7 +18,7 @@ public abstract class NetworkInstructionSymbol extends ResolvableSymbol {
protected
NetworkInstructionSymbol
(
String
name
,
SymbolKind
kind
)
{
super
(
name
,
kind
);
}
public
SerialCompositeElementSymbol
getBody
()
{
return
body
;
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/SerialCompositeElementSymbol.java
View file @
9dd5761f
...
...
@@ -12,6 +12,8 @@ import java.util.*;
public
class
SerialCompositeElementSymbol
extends
CompositeElementSymbol
{
protected
List
<
List
<
ArchitectureElementSymbol
>>
replaySubNetworks
=
new
ArrayList
<>(
new
ArrayList
<>());
protected
void
setElements
(
List
<
ArchitectureElementSymbol
>
elements
)
{
ArchitectureElementSymbol
previous
=
null
;
for
(
ArchitectureElementSymbol
current
:
elements
){
...
...
@@ -32,6 +34,32 @@ public class SerialCompositeElementSymbol extends CompositeElementSymbol {
this
.
elements
=
elements
;
}
protected
void
setReplaySubNetworks
(
List
<
List
<
ArchitectureElementSymbol
>>
replaySubNetworks
){
for
(
List
<
ArchitectureElementSymbol
>
subElements:
replaySubNetworks
){
ArchitectureElementSymbol
previous
=
null
;
for
(
ArchitectureElementSymbol
current
:
subElements
){
if
(
previous
!=
null
){
current
.
setInputElement
(
previous
);
previous
.
setOutputElement
(
current
);
}
else
{
if
(
getInputElement
().
isPresent
()){
current
.
setInputElement
(
getInputElement
().
get
());
}
if
(
getOutputElement
().
isPresent
()){
current
.
setOutputElement
(
getOutputElement
().
get
());
}
}
previous
=
current
;
}
}
this
.
replaySubNetworks
=
replaySubNetworks
;
}
public
List
<
List
<
ArchitectureElementSymbol
>>
getReplaySubNetworks
()
{
return
replaySubNetworks
;
}
@Override
public
void
setInputElement
(
ArchitectureElementSymbol
inputElement
)
{
super
.
setInputElement
(
inputElement
);
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java
View file @
9dd5761f
...
...
@@ -54,6 +54,7 @@ public class AllPredefinedLayers {
public
static
final
String
BROADCAST_ADD_NAME
=
"BroadcastAdd"
;
public
static
final
String
RESHAPE_NAME
=
"Reshape"
;
public
static
final
String
MEMORY_NAME
=
"Memory"
;
public
static
final
String
REPLAY_MEMORY_NAME
=
"ReplayMemory"
;
//predefined argument names
public
static
final
String
KERNEL_NAME
=
"kernel"
;
...
...
@@ -88,13 +89,21 @@ public class AllPredefinedLayers {
public
static
final
String
BEAMSEARCH_WIDTH_NAME
=
"width"
;
public
static
final
String
SHAPE_NAME
=
"shape"
;
public
static
final
String
RNN_DROPOUT_NAME
=
"dropout"
;
//parameters for memory layer
s
//parameters for memory layer
public
static
final
String
SUB_KEY_SIZE_NAME
=
"subKeySize"
;
public
static
final
String
QUERY_SIZE_NAME
=
"querySize"
;
public
static
final
String
ACT_QUERY_NAME
=
"actQuery
"
;
public
static
final
String
QUERY_ACT_NAME
=
"queryAct
"
;
public
static
final
String
K_NAME
=
"k"
;
public
static
final
String
NUM_HEADS_NAME
=
"numHeads"
;
public
static
final
String
VALUE_SHAPE_NAME
=
"valueShape"
;
//parameters for replay memory layer
public
static
final
String
REPLAY_INTERVAL_NAME
=
"replayInterval"
;
public
static
final
String
REPLAY_BATCH_SIZE_NAME
=
"replayBatchSize"
;
public
static
final
String
REPLAY_STEPS_NAME
=
"replaySteps"
;
public
static
final
String
REPLAY_GRADIENT_STEPS_NAME
=
"replayGradientSteps"
;
public
static
final
String
STORE_PROB_NAME
=
"storeProb"
;
public
static
final
String
MAX_STORED_SAMPLES_NAME
=
"maxStoredSamples"
;
//possible String values
public
static
final
String
PADDING_VALID
=
"valid"
;
public
static
final
String
PADDING_SAME
=
"same"
;
...
...
@@ -146,7 +155,8 @@ public class AllPredefinedLayers {
SwapAxes
.
create
(),
BroadcastAdd
.
create
(),
Reshape
.
create
(),
Memory
.
create
());
Memory
.
create
(),
ReplayMemory
.
create
());
}
public
static
List
<
UnrollDeclarationSymbol
>
createUnrollList
(){
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Memory.java
View file @
9dd5761f
...
...
@@ -14,6 +14,7 @@ import java.util.ArrayList;
import
java.util.Arrays
;
import
java.util.Collections
;
import
java.util.List
;
import
java.util.Optional
;
public
class
Memory
extends
PredefinedLayerDeclaration
{
...
...
@@ -24,14 +25,31 @@ public class Memory extends PredefinedLayerDeclaration {
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
int
querySize
=
layer
.
getIntValue
(
AllPredefinedLayers
.
QUERY_SIZE_NAME
).
get
();
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
1
)
.
height
(
querySize
)
.
width
(
1
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
Optional
<
Integer
>
optValue
=
layer
.
getIntValue
(
AllPredefinedLayers
.
QUERY_SIZE_NAME
);
if
(
optValue
.
isPresent
()){
int
querySize
=
optValue
.
get
();
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
1
)
.
height
(
querySize
)
.
width
(
1
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
}
else
{
Optional
<
List
<
Integer
>>
optTupleValue
=
layer
.
getIntTupleValue
(
AllPredefinedLayers
.
QUERY_SIZE_NAME
);
List
<
Integer
>
list
=
new
ArrayList
<>();
for
(
Object
value
:
optTupleValue
.
get
())
{
list
.
add
((
Integer
)
value
);
}
int
listLen
=
list
.
size
();
int
lastEntry
=
list
.
get
(
listLen
-
1
);
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
1
)
.
height
(
lastEntry
)
.
width
(
1
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
}
}
@Override
...
...
@@ -48,11 +66,11 @@ public class Memory extends PredefinedLayerDeclaration {
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
QUERY_SIZE_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
constraints
(
Constraints
.
INTEGER
_OR_INTEGER_TUPLE
,
Constraints
.
POSITIVE
)
.
defaultValue
(
512
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
ACT_QUERY
_NAME
)
.
name
(
AllPredefinedLayers
.
QUERY_ACT
_NAME
)
.
constraints
(
Constraints
.
ACTIVATION_TYPE
)
.
defaultValue
(
"linear"
)
.
build
(),
...
...
@@ -63,6 +81,12 @@ public class Memory extends PredefinedLayerDeclaration {
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
NUM_HEADS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
defaultValue
(
1
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
VALUE_SHAPE_NAME
)
.
constraints
(
Constraints
.
INTEGER_OR_INTEGER_TUPLE
,
Constraints
.
POSITIVE_OR_MINUS_ONE
)
.
defaultValue
(-
1
)
.
build
()));
declaration
.
setParameters
(
parameters
);
return
declaration
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/ReplayMemory.java
0 → 100644
View file @
9dd5761f
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package
de.monticore.lang.monticar.cnnarch.predefined
;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
java.util.ArrayList
;
import
java.util.Arrays
;
import
java.util.Collections
;
import
java.util.List
;
import
java.util.Optional
;
public
class
ReplayMemory
extends
PredefinedLayerDeclaration
{
private
ReplayMemory
()
{
super
(
AllPredefinedLayers
.
REPLAY_MEMORY_NAME
);
}
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
1
)
.
height
(
1
)
.
width
(
1
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
}
@Override
public
void
checkInput
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
errorIfInputSizeIsNotOne
(
inputTypes
,
layer
);
}
public
static
ReplayMemory
create
(){
ReplayMemory
declaration
=
new
ReplayMemory
();
List
<
ParameterSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
REPLAY_INTERVAL_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
REPLAY_BATCH_SIZE_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE_OR_MINUS_ONE
)
.
defaultValue
(-
1
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
REPLAY_STEPS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
defaultValue
(
"linear"
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
REPLAY_GRADIENT_STEPS_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE
)
.
defaultValue
(
1
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
STORE_PROB_NAME
)
.
constraints
(
Constraints
.
NUMBER
,
Constraints
.
BETWEEN_ZERO_AND_ONE
)
.
defaultValue
(
1
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
MAX_STORED_SAMPLES_NAME
)
.
constraints
(
Constraints
.
INTEGER
,
Constraints
.
POSITIVE_OR_MINUS_ONE
)
.
defaultValue
(-
1
)
.
build
()));
declaration
.
setParameters
(
parameters
);
return
declaration
;
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment