monticore / EmbeddedMontiArc / generators / EMADL2CPP · Commits

Commit 2102b34d, authored Aug 25, 2020 by Julian Johannes Steinsberger-Dührßen

    bug fixes, added tests for EpisodicMemory, increased version number

Parent: 7e9f1b0c
Changes: 58
pom.xml

...
@@ -9,19 +9,19 @@
    <groupId>de.monticore.lang.monticar</groupId>
    <artifactId>embedded-montiarc-emadl-generator</artifactId>
-   <version>0.4.0</version>
+   <version>0.4.1</version>

    <!-- == PROJECT DEPENDENCIES ============================================= -->

    <properties>
        <!-- .. SE-Libraries .................................................. -->
-       <emadl.version>0.2.11-SNAPSHOT</emadl.version>
-       <CNNTrain.version>0.3.11-SNAPSHOT</CNNTrain.version>
-       <cnnarch-generator.version>0.0.6-SNAPSHOT</cnnarch-generator.version>
+       <emadl.version>0.2.12-SNAPSHOT</emadl.version>
+       <CNNTrain.version>0.3.12-SNAPSHOT</CNNTrain.version>
+       <cnnarch-generator.version>0.0.7-SNAPSHOT</cnnarch-generator.version>
        <cnnarch-mxnet-generator.version>0.2.17-SNAPSHOT</cnnarch-mxnet-generator.version>
        <cnnarch-caffe2-generator.version>0.2.14-SNAPSHOT</cnnarch-caffe2-generator.version>
-       <cnnarch-gluon-generator.version>0.2.11-SNAPSHOT</cnnarch-gluon-generator.version>
+       <cnnarch-gluon-generator.version>0.2.12-SNAPSHOT</cnnarch-gluon-generator.version>
        <cnnarch-tensorflow-generator.version>0.1.0-SNAPSHOT</cnnarch-tensorflow-generator.version>
        <Common-MontiCar.version>0.0.19-SNAPSHOT</Common-MontiCar.version>
        <embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator>
...
src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java

...
@@ -106,6 +106,13 @@ public class GenerationTest extends AbstractSymtabTest {
        assertTrue(Log.getFindings().isEmpty());
    }

+   @Test
+   public void testEpisodicMemorySimpleGeneration() throws IOException, TemplateException {
+       Log.getFindings().clear();
+       String[] args = {"-m", "src/test/resources/models", "-r", "episodicMemorySimple.Network", "-b", "GLUON", "-f", "n", "-c", "n"};
+       EMADLGeneratorCli.main(args);
+   }
+
    @Test
    public void testMultipleInstances() throws IOException, TemplateException {
        try {
...
@@ -183,7 +190,6 @@ public class GenerationTest extends AbstractSymtabTest {
                "CNNPredictor_mnist_mnistClassifier_net.h",
                "CNNDataLoader_mnist_mnistClassifier_net.py",
                "CNNSupervisedTrainer_mnist_mnistClassifier_net.py",
                "mnist_mnistClassifier_net.h",
                "HelperA.h",
                "CNNTranslator.h",
                "mnist_mnistClassifier_calculateClass.h",
...
@@ -300,9 +306,6 @@ public class GenerationTest extends AbstractSymtabTest {
                "CNNTrainer_defaultGAN_defaultGANConnector_predictor.py",
-               "defaultGAN_defaultGANConnector.cpp",
-               "defaultGAN_defaultGANConnector.h",
-               "defaultGAN_defaultGANConnector_predictor.h",
                "defaultGAN_defaultGANConnector.cpp",
                "defaultGAN_defaultGANConnector.h",
                "defaultGAN_defaultGANConnector_predictor.h"
            )
        );
...
@@ -361,7 +364,7 @@ public class GenerationTest extends AbstractSymtabTest {
        EMADLGeneratorCli.main(args);

        assertEquals(Log.getFindings().size(), 1);
-       assertEquals(Log.getFindings().get(0).toString(), "Tagging info for symbol was found, ignoring data_paths.txt: src/test/resources/models");
+       assertEquals(Log.getFindings().get(0).toString(), "Tagging info for DataPath symbol was found, ignoring data_paths.txt: src/test/resources/models");
        assertTrue(Log.getErrorCount() == 0);
    }
...
src/test/java/de/monticore/lang/monticar/emadl/IntegrationGluonTest.java

...
@@ -70,6 +70,16 @@ public class IntegrationGluonTest extends IntegrationTest {
        assertTrue(Log.getFindings().isEmpty());
    }

+   @Test
+   public void testEpisodicMemorySimple() {
+       Log.getFindings().clear();
+
+       deleteHashFile(Paths.get("./target/generated-sources-emadl/episodicMemorySimple/episodicMemorySimple.training_hash"));
+
+       String[] args = {"-m", "src/test/resources/models", "-r", "episodicMemorySimple.Network", "-b", "GLUON"};
+       EMADLGeneratorCli.main(args);
+   }
+
    @Test
    public void testGluonPreprocessingWithSupervised() {
        Log.getFindings().clear();
...
src/test/resources/models/episodicMemorySimple/Network.cnnt (new file)

/* (c) https://github.com/MontiCore/monticore */
configuration Network{
    num_epoch:1
    batch_size:5
    normalize:false
    context:cpu
    load_checkpoint:false
    loss:cross_entropy
    optimizer:adam{
        learning_rate:0.00003
        weight_decay:0.01
    }
}
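For orientation, the optimizer block above is a plain Adam configuration. Below is a minimal MXNet Gluon sketch of an equivalent trainer setup; the net object is a hypothetical stand-in, since the real network and trainer code are produced by the generator and are not part of this file:

import mxnet as mx
from mxnet import gluon

# Hypothetical stand-in network; the real Net_0 is generated from the EMADL model.
net = gluon.nn.Dense(33)
net.initialize(ctx=mx.cpu())

# Mirrors Network.cnnt: adam with learning_rate 0.00003 and weight_decay 0.01,
# cross-entropy loss, training on CPU with batch_size 5.
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 0.00003, 'wd': 0.01})
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()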
src/test/resources/models/episodicMemorySimple/Network.emadl (new file)

/* (c) https://github.com/MontiCore/monticore */
package episodicMemorySimple;

component Network{
    ports in Z(0:oo)^{10} data,
          out Q(0:1)^{33} softmax;

    implementation CNN {
        data ->
        EpisodicMemory(replayInterval=10, replayBatchSize=100, replaySteps=1, replayGradientSteps=1, replayMemoryStoreProb=0.5, localAdaptionGradientSteps=30, maxStoredSamples=-1, localAdaptionK=32, queryNetDir="tag:simple", queryNetPrefix="simple_embedding-", queryNetNumInputs=1) ->
        LoadNetwork(networkDir="tag:simple", networkPrefix="simple_embedding-", numInputs=1, outputShape=(1,768)) ->
        FullyConnected(units=33) ->
        Softmax() ->
        softmax;
    }
}
src/test/resources/models/episodicMemorySimple/episodicMemorySimple.tag (new file)

/* (c) https://github.com/MontiCore/monticore */
package episodicMemorySimple;
conforms to dltag.DataPathTagSchema, dltag.LayerPathParameterTagSchema;

tags episodic {
    tag Network with DataPath = {path = src/test/resources/training_data/episodicMemorySimple, type = HDF5};
    tag Network with LayerPathParameter = {path = src/test/resources/pretrained/episodicMemorySimple, id = simple};
}
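The DataPath tag above points the generated data loader at HDF5 training data. As a rough illustration only, the snippet below writes a file with the kind of layout such a loader could expect; the dataset names "data" and "softmax_label" are assumptions derived from the model's port names, since the names actually expected by the generated CNNDataLoader are not visible in this commit:

import h5py
import numpy as np

# Assumed layout: input dataset named after the "data" port, labels for the "softmax" output.
with h5py.File("train.h5", "w") as f:
    f.create_dataset("data", data=np.random.randint(0, 100, size=(20, 10)).astype(np.float32))
    f.create_dataset("softmax_label", data=np.random.randint(0, 33, size=(20,)).astype(np.float32))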
src/test/resources/pretrained/episodicMemorySimple/simple_embedding-0000.params (new file)

File added
src/test/resources/pretrained/episodicMemorySimple/simple_embedding-symbol.json (new file)

{
  "nodes": [
    {
      "op": "null",
      "name": "data",
      "inputs": []
    },
    {
      "op": "_copy",
      "name": "simpleembedding0_identity0",
      "inputs": [[0, 0, 0]]
    }
  ],
  "arg_nodes": [0],
  "node_row_ptr": [0, 1, 2],
  "heads": [[1, 0, 0]],
  "attrs": {"mxnet_version": ["int", 10501]}
}
\ No newline at end of file
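The two new resource files form a minimal MXNet symbol/params pair: the symbol graph is a single _copy of the input, so "simple_embedding" acts as an identity embedding for the test. A small sketch, not part of the commit, of how such a pair can be loaded as a Gluon block:

import mxnet as mx
from mxnet import gluon

prefix = "src/test/resources/pretrained/episodicMemorySimple/simple_embedding"
# SymbolBlock.imports reads the -symbol.json graph and the -0000.params weights.
net = gluon.nn.SymbolBlock.imports(prefix + "-symbol.json", ['data'],
                                   param_file=prefix + "-0000.params", ctx=mx.cpu())
out = net(mx.nd.ones((1, 10)))  # the graph is just a _copy op, so out equals the input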
src/test/resources/target_code/gluon/CNNCreator_mnist_mnistClassifier_net.py

...
@@ -2,6 +2,8 @@ import mxnet as mx
import logging
import os
import shutil
+import warnings
+import inspect

from CNNNet_mnist_mnistClassifier_net import Net_0
...
@@ -20,6 +22,10 @@ class CNNCreator_mnist_mnistClassifier_net:
        for i, network in self.networks.items():
            lastEpoch = 0
            param_file = None
+           if hasattr(network, 'episodic_sub_nets'):
+               num_episodic_sub_nets = len(network.episodic_sub_nets)
+               lastMemEpoch = [0]*num_episodic_sub_nets
+               mem_files = [None]*num_episodic_sub_nets

            try:
                os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
...
@@ -30,22 +36,77 @@ class CNNCreator_mnist_mnistClassifier_net:
            except OSError:
                pass
+           if hasattr(network, 'episodic_sub_nets'):
+               try:
+                   os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
+               except OSError:
+                   pass
+               try:
+                   os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")
+               except OSError:
+                   pass
+               for j in range(len(network.episodic_sub_nets)):
+                   try:
+                       os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params")
+                   except OSError:
+                       pass
+                   try:
+                       os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json")
+                   except OSError:
+                       pass
+                   try:
+                       os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params")
+                   except OSError:
+                       pass
+                   try:
+                       os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json")
+                   except OSError:
+                       pass
+                   try:
+                       os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params")
+                   except OSError:
+                       pass
+                   try:
+                       os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json")
+                   except OSError:
+                       pass
+                   try:
+                       os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j+1) + "-0000")
+                   except OSError:
+                       pass

            if os.path.isdir(self._model_dir_):
                for file in os.listdir(self._model_dir_):
-                   if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
+                   if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
                        epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
                        epoch = int(epochStr)
-                       if epoch > lastEpoch:
+                       if epoch >= lastEpoch:
                            lastEpoch = epoch
                            param_file = file
+                   elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+                       relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+                       memSubNet = int(relMemPathInfo[0])
+                       memEpochStr = relMemPathInfo[1]
+                       memEpoch = int(memEpochStr)
+                       if memEpoch >= lastMemEpoch[memSubNet-1]:
+                           lastMemEpoch[memSubNet-1] = memEpoch
+                           mem_files[memSubNet-1] = file

            if param_file is None:
                earliestLastEpoch = 0
            else:
                logging.info("Loading checkpoint: " + param_file)
                network.load_parameters(self._model_dir_ + param_file)
+               if hasattr(network, 'episodic_sub_nets'):
+                   for j, sub_net in enumerate(network.episodic_sub_nets):
+                       if mem_files[j] != None:
+                           logging.info("Loading Replay Memory: " + mem_files[j])
+                           mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
+                           mem_layer.load_memory(self._model_dir_ + mem_files[j])

-               if earliestLastEpoch == None or lastEpoch < earliestLastEpoch:
-                   earliestLastEpoch = lastEpoch
+               if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch:
+                   earliestLastEpoch = lastEpoch + 1

        return earliestLastEpoch
...
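The scan above recovers epochs and episodic-memory sub-net indices purely from checkpoint file names. A small illustration of the naming convention it parses, using hypothetical file names:

# Illustration only (hypothetical file names), mirroring the string handling above.
model_prefix, i = "model", 0

f = "model_0-0012.params"                      # regular weight checkpoint
epoch = int(f.replace(".params", "").replace(model_prefix + "_" + str(i) + "-", ""))
print(epoch)                                   # -> 12

f = "model_0_episodic_memory_sub_net_2-0007"   # replay-memory checkpoint
rel = f.replace(model_prefix + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
print(int(rel[0]), int(rel[1]))                # -> sub-net 2, epoch 7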
@@ -56,27 +117,52 @@ class CNNCreator_mnist_mnistClassifier_net:
        for i, network in self.networks.items():
+           # param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
+           param_file = None
+           if hasattr(network, 'episodic_sub_nets'):
+               num_episodic_sub_nets = len(network.episodic_sub_nets)
+               lastMemEpoch = [0]*num_episodic_sub_nets
+               mem_files = [None]*num_episodic_sub_nets

            if os.path.isdir(self._weights_dir_):
                lastEpoch = 0

                for file in os.listdir(self._weights_dir_):

-                   if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
+                   if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
                        epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
                        epoch = int(epochStr)
-                       if epoch > lastEpoch:
+                       if epoch >= lastEpoch:
                            lastEpoch = epoch
                            param_file = file
+                   elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+                       relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_").split("-")
+                       memSubNet = int(relMemPathInfo[0])
+                       memEpochStr = relMemPathInfo[1]
+                       memEpoch = int(memEpochStr)
+                       if memEpoch >= lastMemEpoch[memSubNet-1]:
+                           lastMemEpoch[memSubNet-1] = memEpoch
+                           mem_files[memSubNet-1] = file

                logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
                network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
+               if hasattr(network, 'episodic_sub_nets'):
+                   assert lastEpoch == lastMemEpoch
+                   for j, sub_net in enumerate(network.episodic_sub_nets):
+                       if mem_files[j] != None:
+                           logging.info("Loading pretrained Replay Memory: " + mem_files[j])
+                           mem_layer = \
+                               [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
+                           mem_layer.load_memory(self._model_dir_ + mem_files[j])
            else:
                logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)

    def construct(self, context, data_mean=None, data_std=None):
-       self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
-       self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
+       self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
+       with warnings.catch_warnings():
+           warnings.simplefilter("ignore")
+           self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
        self.networks[0].hybridize()
-       self.networks[0](mx.nd.zeros((1, 1, 28, 28,), ctx=context))
+       self.networks[0](mx.nd.zeros((1, 1, 28, 28,), ctx=context[0]))

        if not os.path.exists(self._model_dir_):
            os.makedirs(self._model_dir_)
...
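Note that construct() now instantiates Net_0 with an mx_context argument and probes the network with ctx=context[0], i.e. the creator is expected to be called with a list of contexts. A usage sketch under that assumption, assuming the generated module is importable (e.g. with the target_code directory on the Python path):

import mxnet as mx
from CNNCreator_mnist_mnistClassifier_net import CNNCreator_mnist_mnistClassifier_net

creator = CNNCreator_mnist_mnistClassifier_net()
creator.construct(context=[mx.cpu()])  # list of contexts; construct() indexes context[0]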
src/test/resources/target_code/gluon/CNNNet_mnist_mnistClassifier_net.py

import mxnet as mx
import numpy as np
import math
-from mxnet import gluon
+import os
+import abc
+import warnings
+from mxnet import gluon, nd


class ZScoreNormalization(gluon.HybridBlock):
...
@@ -86,9 +89,419 @@ class CustomGRU(gluon.HybridBlock):
        output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)])

        return output, F.swapaxes(state0, 0, 1)


class DotProductSelfAttention(gluon.HybridBlock):
    def __init__(self, scale_factor, num_heads, dim_model, dim_keys, dim_values, use_proj_bias, use_mask, **kwargs):
        super(DotProductSelfAttention, self).__init__(**kwargs)
        with self.name_scope():
            self.num_heads = num_heads
            self.dim_model = dim_model
            self.use_proj_bias = use_proj_bias
            self.use_mask = use_mask

            if dim_keys == -1:
                self.dim_keys = int(dim_model / self.num_heads)
            else:
                self.dim_keys = dim_keys
            if dim_values == -1:
                self.dim_values = int(dim_model / self.num_heads)
            else:
                self.dim_values = dim_values

            if scale_factor == -1:
                self.scale_factor = math.sqrt(self.dim_keys)
            else:
                self.scale_factor = scale_factor

            self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
            self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
            self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False)
            self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False)

    def hybrid_forward(self, F, queries, keys, values, *args, **kwargs):
        queries = F.Reshape(queries, shape=(0, 0, -1))
        keys = F.Reshape(queries, shape=(0, 0, -1))
        values = F.Reshape(queries, shape=(0, 0, -1))

        head_queries = self.proj_q(queries)
        head_keys = self.proj_k(keys)
        head_values = self.proj_v(values)

        head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1))
        head_queries = F.transpose(head_queries, axes=(0, 2, 1, 3))
        head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True)

        head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1))
        head_keys = F.transpose(head_keys, axes=(0, 2, 1, 3))
        head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True)

        score = F.batch_dot(head_queries, head_keys, transpose_b=True)
        score = score * self.scale_factor
        if self.use_mask:
            mask = F.tile(mask, self.num_heads)
            mask = F.repeat(mask, self.dim_model)
            mask = F.reshape(mask, shape=(-1, self.dim_model))
        weights = F.softmax(score, mask, use_length=self.use_mask)

        head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1))
        head_values = F.transpose(head_values, axes=(0, 2, 1, 3))
        head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True)

        ret = F.batch_dot(weights, head_values)
        ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True)
        ret = F.transpose(ret, axes=(0, 2, 1, 3))
        ret = F.reshape(ret, shape=(0, 0, -1))

        ret = self.proj_o(ret)

        return ret
class EpisodicReplayMemoryInterface(gluon.HybridBlock):
    __metaclass__ = abc.ABCMeta

    def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
        super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)

        self.use_replay = use_replay
        self.replay_interval = replay_interval
        self.replay_batch_size = replay_batch_size
        self.replay_steps = replay_steps
        self.replay_gradient_steps = replay_gradient_steps
        self.num_heads = num_heads

    @abc.abstractmethod
    def store_samples(self, data, y, query_network, store_prob, mx_context):
        pass

    @abc.abstractmethod
    def sample_memory(self, batch_size, mx_context):
        pass

    @abc.abstractmethod
    def get_query_network(self, mx_context):
        pass

    @abc.abstractmethod
    def save_memory(self, path):
        pass

    @abc.abstractmethod
    def load_memory(self, path):
        pass
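EpisodicReplayMemoryInterface only fixes the contract that concrete episodic-memory blocks have to fulfil. Below is a toy subclass, purely illustrative and not part of the generated code, showing the five methods a concrete memory must provide:

# Toy subclass for illustration: keeps samples in a plain Python list instead of a
# learned key-value memory. Constructor parameter values are arbitrary assumptions.
class ToyReplayMemory(EpisodicReplayMemoryInterface):
    def __init__(self, **kwargs):
        super(ToyReplayMemory, self).__init__(use_replay=True, replay_interval=10,
                                              replay_batch_size=8, replay_steps=1,
                                              replay_gradient_steps=1, num_heads=1, **kwargs)
        self.samples = []

    def store_samples(self, data, y, query_network, store_prob, mx_context):
        self.samples.append((data, y))

    def sample_memory(self, batch_size, mx_context):
        return self.samples[-batch_size:]

    def get_query_network(self, mx_context):
        return None

    def save_memory(self, path):
        pass

    def load_memory(self, path):
        pass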
#Memory layer
class LargeMemory(gluon.HybridBlock):
    def __init__(self, sub_key_size, query_size, query_act, dist_measure, k, num_heads, values_dim, **kwargs):
        super(LargeMemory, self).__init__(**kwargs)
        with self.name_scope():
            #Memory parameters
            self.dist_measure = dist_measure
            self.k = k
            self.num_heads = num_heads
            self.query_act = query_act
            self.query_size = query_size
            self.num_heads = num_heads

            #Batch norm sub-layer
            self.batch_norm = gluon.nn.BatchNorm()

            #Memory sub-layer
            self.sub_key_size = sub_key_size
            sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2))

            if values_dim == -1:
                values_shape = (self.sub_key_size*self.sub_key_size, self.query_size[-1])
            else:
                values_shape = (self.sub_key_size*self.sub_key_size, values_dim)

            self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True)
            self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True)
            self.values = self.params.get("values", shape=values_shape, differentiable=True)
            self.label_memory = nd.array([])

            self.get_query_network()

    def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values):
        x = self.batch_norm(x)

        x = F.reshape(x, shape=(0, -1))

        q = self.query_network(x)

        q = F.reshape(q, shape=(0, self.num_heads, -1))

        q_split = F.split(q, num_outputs=2, axis=-1)

        if self.dist_measure == "l2":
            q_split_resh = F.reshape(q_split[0], shape=(0, 0, 1, -1))
            sub_keys1_resh = F.reshape(sub_keys1, shape=(1, 0, 0, -1), reverse=True)
            q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh)
            q1_dist = F.norm(q1_diff, axis=-1)

            q_split_resh = F.reshape(q_split[1], shape=(0, 0, 1, -1))
            sub_keys2_resh = F.reshape(sub_keys2, shape=(1, 0, 0, -1), reverse=True)
            q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh)
            q2_dist = F.norm(q2_diff, axis=-1)
        else:
            q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
            q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
            sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
            sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
            if self.num_heads == 1:
                q1 = [q1]
                q2 = [q2]
                sub_keys1_resh = [sub_keys1_resh]
                sub_keys2_resh = [sub_keys2_resh]

            q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
            q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
            for h in range(1, self.num_heads):
                q1_dist = F.concat(q1_dist, F.dot(q1[0], sub_keys1_resh[h], transpose_b=True), dim=1)
                q2_dist = F.concat(q2_dist, F.dot(q2[0], sub_keys1_resh[h], transpose_b=True), dim=1)

        i1 = F.topk(q1_dist, k=self.k, ret_typ="indices")
        i2 = F.topk(q2_dist, k=self.k, ret_typ="indices")

        # Calculate cross product for keys at indices I1 and I2

        # def head_take(data, state):
        #     return [F.take(data[0], data[2]), F.take(data[1], data[3])], state,
        #
        # i1 = F.transpose(i1, axes=(1,0,2))
        # i2 = F.transpose(i2, axes=(1, 0, 2))
        # st = F.zeros(1)
        # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st)
        # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True)
        # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True)

        i1 = F.split(i1, num_outputs=self.num_heads, axis=1)
        i2 = F.split(i2, num_outputs=self.num_heads, axis=1)
        sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
        sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
        if self.num_heads == 1:
            i1 = [i1]
            i2 = [i2]
            sub_keys1 = [sub_keys1]
            sub_keys2 = [sub_keys2]

        k1 = F.take(sub_keys1[0], i1[0])
        k2 = F.take(sub_keys2[0], i2[0])
        for h in range(1, self.num_heads):
            k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1)
            k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1)

        k1 = F.tile(k1, (1, 1, self.k, 1))
        k2 = F.repeat(k2, self.k, 2)
        c_cart = F.concat(k1, k2, dim=3)

        q = F.reshape(q, shape=(-1, 0), reverse=True)
        q = F.reshape(q, shape=(0, 1, -1))
        c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True)