Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
What's new
7
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Open sidebar
monticore
EmbeddedMontiArc
languages
CNNArchLang
Commits
7a0f779b
Commit
7a0f779b
authored
Jul 19, 2020
by
Julian Johannes Steinsberger-Dührßen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Change Memory Layer names; Added Load Network Layer; Multiple input support for episodic Memory
parent
04607854
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
158 additions
and
65 deletions
+158
-65
src/main/java/de/monticore/lang/monticar/cnnarch/_cocos/CNNArchCocos.java
.../monticore/lang/monticar/cnnarch/_cocos/CNNArchCocos.java
+2
-2
src/main/java/de/monticore/lang/monticar/cnnarch/_cocos/CheckEpisodicMemoryLayer.java
...ang/monticar/cnnarch/_cocos/CheckEpisodicMemoryLayer.java
+9
-9
src/main/java/de/monticore/lang/monticar/cnnarch/_cocos/CheckLargeMemoryLayer.java
...e/lang/monticar/cnnarch/_cocos/CheckLargeMemoryLayer.java
+6
-6
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchTypeSymbol.java
...re/lang/monticar/cnnarch/_symboltable/ArchTypeSymbol.java
+1
-2
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java
...ang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java
+11
-11
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/SerialCompositeElementSymbol.java
...ar/cnnarch/_symboltable/SerialCompositeElementSymbol.java
+6
-6
src/main/java/de/monticore/lang/monticar/cnnarch/helper/ErrorCodes.java
...de/monticore/lang/monticar/cnnarch/helper/ErrorCodes.java
+3
-3
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java
...lang/monticar/cnnarch/predefined/AllPredefinedLayers.java
+16
-8
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/EpisodicMemory.java
...core/lang/monticar/cnnarch/predefined/EpisodicMemory.java
+26
-13
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/LargeMemory.java
...nticore/lang/monticar/cnnarch/predefined/LargeMemory.java
+5
-5
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/LoadNetwork.java
...nticore/lang/monticar/cnnarch/predefined/LoadNetwork.java
+73
-0
No files found.
src/main/java/de/monticore/lang/monticar/cnnarch/_cocos/CNNArchCocos.java
View file @
7a0f779b
...
...
@@ -60,8 +60,8 @@ public class CNNArchCocos {
.
addCoCo
(
new
CheckLayerVariableDeclarationLayerType
())
.
addCoCo
(
new
CheckLayerVariableDeclarationIsUsed
())
.
addCoCo
(
new
CheckConstants
())
.
addCoCo
(
new
CheckMemoryLayer
())
.
addCoCo
(
new
Check
Replay
MemoryLayer
())
.
addCoCo
(
new
Check
Large
MemoryLayer
())
.
addCoCo
(
new
Check
Episodic
MemoryLayer
())
.
addCoCo
(
new
CheckUnrollInputsOutputsTooMany
());
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_cocos/Check
Replay
MemoryLayer.java
→
src/main/java/de/monticore/lang/monticar/cnnarch/_cocos/Check
Episodic
MemoryLayer.java
View file @
7a0f779b
...
...
@@ -21,7 +21,7 @@ import java.util.Optional;
import
java.util.List
;
import
java.io.File
;
public
class
Check
Replay
MemoryLayer
extends
CNNArchSymbolCoCo
{
public
class
Check
Episodic
MemoryLayer
extends
CNNArchSymbolCoCo
{
@Override
public
void
check
(
StreamInstructionSymbol
stream
)
{
...
...
@@ -29,22 +29,22 @@ public class CheckReplayMemoryLayer extends CNNArchSymbolCoCo {
for
(
ArchitectureElementSymbol
element
:
elements
)
{
if
(
element
instanceof
ParallelCompositeElementSymbol
)
{
checkFor
Replay
Memory
((
ParallelCompositeElementSymbol
)
element
);
}
else
if
(
element
.
getName
().
equals
(
"
Replay
Memory"
))
{
checkFor
Episodic
Memory
((
ParallelCompositeElementSymbol
)
element
);
}
else
if
(
element
.
getName
().
equals
(
"
Episodic
Memory"
))
{
checkParameters
((
LayerSymbol
)
element
);
}
}
}
public
void
checkFor
Replay
Memory
(
ParallelCompositeElementSymbol
parallelElement
)
{
public
void
checkFor
Episodic
Memory
(
ParallelCompositeElementSymbol
parallelElement
)
{
for
(
ArchitectureElementSymbol
subStream
:
parallelElement
.
getElements
())
{
if
(
subStream
instanceof
SerialCompositeElementSymbol
)
{
//should always be the case
for
(
ArchitectureElementSymbol
element
:
((
SerialCompositeElementSymbol
)
subStream
).
getElements
())
{
if
(
element
instanceof
ParallelCompositeElementSymbol
)
{
checkFor
Replay
Memory
((
ParallelCompositeElementSymbol
)
element
);
}
else
if
(
element
.
getName
().
equals
(
"
Replay
Memory"
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_
R
EP
LAY
_MEMORY_LAYER_PLACEMENT
+
" Invalid placement of
Replay
Memory layer. It can't be placed inside a Prallalel execution block."
,
checkFor
Episodic
Memory
((
ParallelCompositeElementSymbol
)
element
);
}
else
if
(
element
.
getName
().
equals
(
"
Episodic
Memory"
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_EP
ISODIC
_MEMORY_LAYER_PLACEMENT
+
" Invalid placement of
Episodic
Memory layer. It can't be placed inside a Prallalel execution block."
,
element
.
getSourcePosition
());
}
}
...
...
@@ -74,7 +74,7 @@ public class CheckReplayMemoryLayer extends CNNArchSymbolCoCo {
}
}
}
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_
R
EP
LAY
_QUERY_NET_PATH_OR_PREFIX
+
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_EP
ISODIC
_QUERY_NET_PATH_OR_PREFIX
+
" For the concatination of queryNetDir and queryNetPrefix exists no file wich path has this as prefix."
,
layer
.
getSourcePosition
());
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_cocos/CheckMemoryLayer.java
→
src/main/java/de/monticore/lang/monticar/cnnarch/_cocos/Check
Large
MemoryLayer.java
View file @
7a0f779b
...
...
@@ -17,16 +17,16 @@ import java.util.Optional;
import
java.util.List
;
public
class
CheckMemoryLayer
extends
CNNArchSymbolCoCo
{
public
class
Check
Large
MemoryLayer
extends
CNNArchSymbolCoCo
{
@Override
public
void
check
(
ArchitectureElementSymbol
sym
)
{
if
(
sym
instanceof
LayerSymbol
&&
sym
.
getName
().
equals
(
"Memory"
))
{
checkMemoryLayer
((
LayerSymbol
)
sym
);
if
(
sym
instanceof
LayerSymbol
&&
sym
.
getName
().
equals
(
"
Large
Memory"
))
{
check
Large
MemoryLayer
((
LayerSymbol
)
sym
);
}
}
public
void
checkMemoryLayer
(
LayerSymbol
layer
)
{
public
void
check
Large
MemoryLayer
(
LayerSymbol
layer
)
{
List
<
ArgumentSymbol
>
arguments
=
layer
.
getArguments
();
Integer
subKeySize
=
new
Integer
(
0
);
Integer
k
=
new
Integer
(
0
);
...
...
@@ -40,8 +40,8 @@ public class CheckMemoryLayer extends CNNArchSymbolCoCo {
}
if
(
subKeySize
<
k
)
{
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_MEMORY_LAYER_PARAMETERS
+
" Invalid Memory layer Parameter values, subKeySize has to be greater or equal to k. "
,
Log
.
error
(
"0"
+
ErrorCodes
.
INVALID_
LARGE_
MEMORY_LAYER_PARAMETERS
+
" Invalid
Large
Memory layer Parameter values, subKeySize has to be greater or equal to k. "
,
layer
.
getSourcePosition
());
}
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchTypeSymbol.java
View file @
7a0f779b
...
...
@@ -30,7 +30,6 @@ public class ArchTypeSymbol extends CommonSymbol {
private
int
widthIndex
=
-
1
;
private
List
<
ArchSimpleExpressionSymbol
>
dimensions
=
new
ArrayList
<>();
public
ArchTypeSymbol
()
{
super
(
""
,
KIND
);
ASTElementType
elementType
=
new
ASTElementType
();
...
...
@@ -146,7 +145,7 @@ public class ArchTypeSymbol extends CommonSymbol {
}
return
dimensionList
;
}
public
Set
<
ParameterSymbol
>
resolve
()
{
if
(!
isResolved
()){
if
(
isResolvable
()){
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java
View file @
7a0f779b
...
...
@@ -214,15 +214,15 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
return
copy
;
}
public
void
processForReplayMemory
(){
public
void
processFor
Episodic
ReplayMemory
(){
for
(
NetworkInstructionSymbol
networkInstruction
:
networkInstructions
){
List
<
ArchitectureElementSymbol
>
elements
=
networkInstruction
.
getBody
().
getElements
();
List
<
ArchitectureElementSymbol
>
elementsNew
=
new
ArrayList
<>();
List
<
List
<
ArchitectureElementSymbol
>>
r
ep
lay
SubNetworks
=
new
ArrayList
<>(
new
ArrayList
<>());
List
<
ArchitectureElementSymbol
>
current
Replay
SubNetworkElements
=
new
ArrayList
<>();
List
<
List
<
ArchitectureElementSymbol
>>
ep
isodic
SubNetworks
=
new
ArrayList
<>(
new
ArrayList
<>());
List
<
ArchitectureElementSymbol
>
current
Episodic
SubNetworkElements
=
new
ArrayList
<>();
for
(
ArchitectureElementSymbol
element
:
elements
){
if
(
AllPredefinedLayers
.
REPLAY_LAYER_NAMES
.
contains
(
element
.
getName
()))
{
if
(
AllPredefinedLayers
.
EPISODIC_
REPLAY_LAYER_NAMES
.
contains
(
element
.
getName
()))
{
boolean
use_replay
=
false
;
boolean
use_local_adaption
=
false
;
...
...
@@ -251,18 +251,18 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
}
if
(
use_replay
||
use_local_adaption
){
if
(!
current
Replay
SubNetworkElements
.
isEmpty
()){
r
ep
lay
SubNetworks
.
add
(
current
Replay
SubNetworkElements
);
if
(!
current
Episodic
SubNetworkElements
.
isEmpty
()){
ep
isodic
SubNetworks
.
add
(
current
Episodic
SubNetworkElements
);
}
current
Replay
SubNetworkElements
=
new
ArrayList
<>();
current
Episodic
SubNetworkElements
=
new
ArrayList
<>();
}
}
current
Replay
SubNetworkElements
.
add
(
element
);
current
Episodic
SubNetworkElements
.
add
(
element
);
}
if
(!
current
Replay
SubNetworkElements
.
isEmpty
()
&&
!
r
ep
lay
SubNetworks
.
isEmpty
()){
r
ep
lay
SubNetworks
.
add
(
current
Replay
SubNetworkElements
);
if
(!
current
Episodic
SubNetworkElements
.
isEmpty
()
&&
!
ep
isodic
SubNetworks
.
isEmpty
()){
ep
isodic
SubNetworks
.
add
(
current
Episodic
SubNetworkElements
);
}
networkInstruction
.
getBody
().
set
Replay
SubNetworks
(
r
ep
lay
SubNetworks
);
networkInstruction
.
getBody
().
set
Episodic
SubNetworks
(
ep
isodic
SubNetworks
);
}
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/SerialCompositeElementSymbol.java
View file @
7a0f779b
...
...
@@ -12,7 +12,7 @@ import java.util.*;
public
class
SerialCompositeElementSymbol
extends
CompositeElementSymbol
{
protected
List
<
List
<
ArchitectureElementSymbol
>>
r
ep
lay
SubNetworks
=
new
ArrayList
<>(
new
ArrayList
<>());
protected
List
<
List
<
ArchitectureElementSymbol
>>
ep
isodic
SubNetworks
=
new
ArrayList
<>(
new
ArrayList
<>());
protected
void
setElements
(
List
<
ArchitectureElementSymbol
>
elements
)
{
ArchitectureElementSymbol
previous
=
null
;
...
...
@@ -34,8 +34,8 @@ public class SerialCompositeElementSymbol extends CompositeElementSymbol {
this
.
elements
=
elements
;
}
protected
void
set
Replay
SubNetworks
(
List
<
List
<
ArchitectureElementSymbol
>>
r
ep
lay
SubNetworks
){
for
(
List
<
ArchitectureElementSymbol
>
subElements:
r
ep
lay
SubNetworks
){
protected
void
set
Episodic
SubNetworks
(
List
<
List
<
ArchitectureElementSymbol
>>
ep
isodic
SubNetworks
){
for
(
List
<
ArchitectureElementSymbol
>
subElements:
ep
isodic
SubNetworks
){
ArchitectureElementSymbol
previous
=
null
;
for
(
ArchitectureElementSymbol
current
:
subElements
){
if
(
previous
!=
null
){
...
...
@@ -53,11 +53,11 @@ public class SerialCompositeElementSymbol extends CompositeElementSymbol {
previous
=
current
;
}
}
this
.
r
ep
lay
SubNetworks
=
r
ep
lay
SubNetworks
;
this
.
ep
isodic
SubNetworks
=
ep
isodic
SubNetworks
;
}
public
List
<
List
<
ArchitectureElementSymbol
>>
get
Replay
SubNetworks
()
{
return
r
ep
lay
SubNetworks
;
public
List
<
List
<
ArchitectureElementSymbol
>>
get
Episodic
SubNetworks
()
{
return
ep
isodic
SubNetworks
;
}
@Override
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/helper/ErrorCodes.java
View file @
7a0f779b
...
...
@@ -37,9 +37,9 @@ public class ErrorCodes {
public
static
final
String
ILLEGAL_LAYER_USE
=
"x04845"
;
public
static
final
String
UNUSED_LAYER
=
"x04847"
;
public
static
final
String
INVALID_CONSTANT
=
"x04856"
;
public
static
final
String
INVALID_MEMORY_LAYER_PARAMETERS
=
"x04866"
;
public
static
final
String
INVALID_
R
EP
LAY
_MEMORY_LAYER_PLACEMENT
=
"x04876"
;
public
static
final
String
INVALID_
R
EP
LAY
_QUERY_NET_PATH_OR_PREFIX
=
"x04877"
;
public
static
final
String
INVALID_
LARGE_
MEMORY_LAYER_PARAMETERS
=
"x04866"
;
public
static
final
String
INVALID_EP
ISODIC
_MEMORY_LAYER_PLACEMENT
=
"x04876"
;
public
static
final
String
INVALID_EP
ISODIC
_QUERY_NET_PATH_OR_PREFIX
=
"x04877"
;
public
static
final
String
OUTPUT_WRITTEN_TO_MULTIPLE_TIMES
=
"x04836"
;
public
static
final
String
UNROLL_INPUTS_TOO_MANY
=
"x02384"
;
public
static
final
String
UNROLL_OUTPUTS_TOO_MANY
=
"x02385"
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java
View file @
7a0f779b
...
...
@@ -55,11 +55,12 @@ public class AllPredefinedLayers {
public
static
final
String
BROADCAST_ADD_NAME
=
"BroadcastAdd"
;
public
static
final
String
RESHAPE_NAME
=
"Reshape"
;
public
static
final
String
DOT_PRODUCT_SELF_ATTENTION_NAME
=
"DotProductSelfAttention"
;
public
static
final
String
LOAD_NETWORK_NAME
=
"LoadNetwork"
;
//replay layers
public
static
final
String
MEMORY_NAME
=
"Memory"
;
public
static
final
String
R
EP
LAY
_MEMORY_NAME
=
"
Replay
Memory"
;
public
static
final
List
<
String
>
REPLAY_LAYER_NAMES
=
new
ArrayList
<
String
>(
Arrays
.
asList
(
R
EP
LAY
_MEMORY_NAME
));
public
static
final
String
LARGE_
MEMORY_NAME
=
"
Large
Memory"
;
public
static
final
String
EP
ISODIC
_MEMORY_NAME
=
"
Episodic
Memory"
;
public
static
final
List
<
String
>
EPISODIC_
REPLAY_LAYER_NAMES
=
new
ArrayList
<
String
>(
Arrays
.
asList
(
EP
ISODIC
_MEMORY_NAME
));
//predefined argument names
...
...
@@ -96,6 +97,11 @@ public class AllPredefinedLayers {
public
static
final
String
SHAPE_NAME
=
"shape"
;
public
static
final
String
RNN_DROPOUT_NAME
=
"dropout"
;
//parameters LoadNetwork layer
public
static
final
String
NETWORK_DIR_NAME
=
"networkDir"
;
public
static
final
String
NETWORK_PREFIX_NAME
=
"networkPrefix"
;
public
static
final
String
NUM_INPUTS_NAME
=
"numInputs"
;
public
static
final
String
OUTPUT_SHAPE_NAME
=
"outputShape"
;
//parameters DotProductSelfAttention
public
static
final
String
SCALE_FACTOR_NAME
=
"scaleFactor"
;
...
...
@@ -103,7 +109,7 @@ public class AllPredefinedLayers {
public
static
final
String
DIM_VALUES_NAME
=
"dimValues"
;
public
static
final
String
USE_PROJ_BIAS_NAME
=
"useProjBias"
;
//shared parameters replay layers
//shared parameters
episodic
replay layers
public
static
final
String
USE_REPLAY_NAME
=
"useReplay"
;
public
static
final
String
REPLAY_INTERVAL_NAME
=
"replayInterval"
;
public
static
final
String
REPLAY_BATCH_SIZE_NAME
=
"replayBatchSize"
;
...
...
@@ -115,7 +121,7 @@ public class AllPredefinedLayers {
public
static
final
String
LOCAL_ADAPTION_GRADIENT_STEPS_NAME
=
"localAdaptionGradientSteps"
;
public
static
final
String
LOCAL_ADAPTION_MEMORY_STORE_DIST_MEASURE_NAME
=
"localAdaptionMemoryStoreDistMeasure"
;
//parameters for memory layer
//parameters for
episodic
memory layer
public
static
final
String
SUB_KEY_SIZE_NAME
=
"subKeySize"
;
public
static
final
String
QUERY_SIZE_NAME
=
"querySize"
;
public
static
final
String
QUERY_ACT_NAME
=
"queryAct"
;
...
...
@@ -124,11 +130,12 @@ public class AllPredefinedLayers {
public
static
final
String
STORE_DIST_MEASURE_NAME
=
"storeDistMeasure"
;
public
static
final
String
VALUES_DIM_NAME
=
"valuesDim"
;
//parameters for
r
ep
lay
memory layer
//parameters for ep
isodic
memory layer
public
static
final
String
MAX_STORED_SAMPLES_NAME
=
"maxStoredSamples"
;
public
static
final
String
REPLAY_MEMORY_STORE_PROB_NAME
=
"replayMemoryStoreProb"
;
public
static
final
String
QUERY_NET_DIR_NAME
=
"queryNetDir"
;
public
static
final
String
QUERY_NET_PREFIX_NAME
=
"queryNetPrefix"
;
public
static
final
String
QUERY_NET_NUM_INPUTS_NAME
=
"queryNetNumInputs"
;
//possible String values
public
static
final
String
PADDING_VALID
=
"valid"
;
...
...
@@ -184,9 +191,10 @@ public class AllPredefinedLayers {
SwapAxes
.
create
(),
BroadcastAdd
.
create
(),
Reshape
.
create
(),
LoadNetwork
.
create
(),
DotProductSelfAttention
.
create
(),
Memory
.
create
(),
Replay
Memory
.
create
());
Large
Memory
.
create
(),
Episodic
Memory
.
create
());
}
public
static
List
<
UnrollDeclarationSymbol
>
createUnrollList
(){
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/
Replay
Memory.java
→
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/
Episodic
Memory.java
View file @
7a0f779b
...
...
@@ -16,31 +16,39 @@ import java.util.Collections;
import
java.util.List
;
import
java.util.Optional
;
public
class
Replay
Memory
extends
PredefinedLayerDeclaration
{
public
class
Episodic
Memory
extends
PredefinedLayerDeclaration
{
private
Replay
Memory
()
{
super
(
AllPredefinedLayers
.
R
EP
LAY
_MEMORY_NAME
);
private
Episodic
Memory
()
{
super
(
AllPredefinedLayers
.
EP
ISODIC
_MEMORY_NAME
);
}
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
1
)
.
height
(
1
)
.
width
(
1
)
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
List
<
ArchTypeSymbol
>
outputShapes
=
new
ArrayList
<>(
layer
.
getInputTypes
().
size
());
for
(
int
i
=
0
;
i
<
layer
.
getInputTypes
().
size
();
i
++)
{
ArchTypeSymbol
inputShape
=
layer
.
getInputTypes
().
get
(
i
);
int
inputHeight
=
inputShape
.
getHeight
();
int
inputWidth
=
inputShape
.
getWidth
();
int
inputChannels
=
inputShape
.
getChannels
();
outputShapes
.
add
(
new
ArchTypeSymbol
.
Builder
()
.
height
(
inputHeight
)
.
width
(
inputWidth
)
.
channels
(
inputChannels
)
.
elementType
(
layer
.
getInputTypes
().
get
(
i
).
getDomain
())
.
build
());
}
return
outputShapes
;
}
@Override
public
void
checkInput
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
errorIfInput
SizeIsNotOne
(
inputTypes
,
layer
);
errorIfInput
IsEmpty
(
inputTypes
,
layer
);
}
public
static
Replay
Memory
create
(){
Replay
Memory
declaration
=
new
Replay
Memory
();
public
static
Episodic
Memory
create
(){
Episodic
Memory
declaration
=
new
Episodic
Memory
();
List
<
ParameterSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
USE_REPLAY_NAME
)
...
...
@@ -104,6 +112,11 @@ public class ReplayMemory extends PredefinedLayerDeclaration {
.
name
(
AllPredefinedLayers
.
QUERY_NET_PREFIX_NAME
)
.
constraints
(
Constraints
.
STRING
)
.
defaultValue
(-
1
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
QUERY_NET_NUM_INPUTS_NAME
)
.
constraints
(
Constraints
.
INTEGER
)
.
defaultValue
(
1
)
.
build
()));
declaration
.
setParameters
(
parameters
);
return
declaration
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Memory.java
→
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/
Large
Memory.java
View file @
7a0f779b
...
...
@@ -16,10 +16,10 @@ import java.util.Collections;
import
java.util.List
;
import
java.util.Optional
;
public
class
Memory
extends
PredefinedLayerDeclaration
{
public
class
Large
Memory
extends
PredefinedLayerDeclaration
{
private
Memory
()
{
super
(
AllPredefinedLayers
.
MEMORY_NAME
);
private
Large
Memory
()
{
super
(
AllPredefinedLayers
.
LARGE_
MEMORY_NAME
);
}
@Override
...
...
@@ -57,8 +57,8 @@ public class Memory extends PredefinedLayerDeclaration {
errorIfInputSizeIsNotOne
(
inputTypes
,
layer
);
}
public
static
Memory
create
(){
Memory
declaration
=
new
Memory
();
public
static
Large
Memory
create
(){
Large
Memory
declaration
=
new
Large
Memory
();
List
<
ParameterSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
STORE_DIST_MEASURE_NAME
)
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/predefined/LoadNetwork.java
0 → 100644
View file @
7a0f779b
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package
de.monticore.lang.monticar.cnnarch.predefined
;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
java.util.ArrayList
;
import
java.util.Arrays
;
import
java.util.Collections
;
import
java.util.List
;
import
java.util.Optional
;
public
class
LoadNetwork
extends
PredefinedLayerDeclaration
{
private
LoadNetwork
()
{
super
(
AllPredefinedLayers
.
LOAD_NETWORK_NAME
);
}
@Override
public
List
<
ArchTypeSymbol
>
computeOutputTypes
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
Optional
<
List
<
Integer
>>
optValue
=
layer
.
getIntTupleValue
(
AllPredefinedLayers
.
OUTPUT_SHAPE_NAME
);
List
<
Integer
>
shapeList
=
Arrays
.
asList
(
1
,
1
,
1
);
if
(
optValue
.
isPresent
())
{
List
<
Integer
>
outputShape
=
optValue
.
get
();
for
(
int
i
=
0
;
i
<
outputShape
.
size
()
&&
i
<
3
;
i
++)
{
shapeList
.
set
(
i
,
outputShape
.
get
(
i
));
}
}
return
Collections
.
singletonList
(
new
ArchTypeSymbol
.
Builder
()
.
channels
(
shapeList
.
get
(
0
))
.
height
(
shapeList
.
get
(
1
))
.
width
(
shapeList
.
get
(
2
))
.
elementType
(
"-oo"
,
"oo"
)
.
build
());
}
@Override
public
void
checkInput
(
List
<
ArchTypeSymbol
>
inputTypes
,
LayerSymbol
layer
,
VariableSymbol
.
Member
member
)
{
errorIfInputIsEmpty
(
inputTypes
,
layer
);
}
public
static
LoadNetwork
create
(){
LoadNetwork
declaration
=
new
LoadNetwork
();
List
<
ParameterSymbol
>
parameters
=
new
ArrayList
<>(
Arrays
.
asList
(
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
NETWORK_DIR_NAME
)
.
constraints
(
Constraints
.
STRING
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
NETWORK_PREFIX_NAME
)
.
constraints
(
Constraints
.
STRING
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
NUM_INPUTS_NAME
)
.
constraints
(
Constraints
.
INTEGER
)
.
build
(),
new
ParameterSymbol
.
Builder
()
.
name
(
AllPredefinedLayers
.
OUTPUT_SHAPE_NAME
)
.
constraints
(
Constraints
.
INTEGER_TUPLE
)
.
build
()));
declaration
.
setParameters
(
parameters
);
return
declaration
;
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment