Commit 2102b34d authored by Julian Johannes Steinsberger-Dührßen

bug fixes, added tests for EpisodicMemory, increased version number

parent 7e9f1b0c
@@ -9,19 +9,19 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-emadl-generator</artifactId>
- <version>0.4.0</version>
+ <version>0.4.1</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
- <emadl.version>0.2.11-SNAPSHOT</emadl.version>
- <CNNTrain.version>0.3.11-SNAPSHOT</CNNTrain.version>
- <cnnarch-generator.version>0.0.6-SNAPSHOT</cnnarch-generator.version>
+ <emadl.version>0.2.12-SNAPSHOT</emadl.version>
+ <CNNTrain.version>0.3.12-SNAPSHOT</CNNTrain.version>
+ <cnnarch-generator.version>0.0.7-SNAPSHOT</cnnarch-generator.version>
<cnnarch-mxnet-generator.version>0.2.17-SNAPSHOT</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>0.2.14-SNAPSHOT</cnnarch-caffe2-generator.version>
- <cnnarch-gluon-generator.version>0.2.11-SNAPSHOT</cnnarch-gluon-generator.version>
+ <cnnarch-gluon-generator.version>0.2.12-SNAPSHOT</cnnarch-gluon-generator.version>
<cnnarch-tensorflow-generator.version>0.1.0-SNAPSHOT</cnnarch-tensorflow-generator.version>
<Common-MontiCar.version>0.0.19-SNAPSHOT</Common-MontiCar.version>
<embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator>
@@ -106,6 +106,13 @@ public class GenerationTest extends AbstractSymtabTest {
assertTrue(Log.getFindings().isEmpty());
}
@Test
public void testEpisodicMemorySimpleGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models", "-r", "episodicMemorySimple.Network", "-b", "GLUON", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
}
@Test
public void testMultipleInstances() throws IOException, TemplateException {
try {
@@ -183,7 +190,6 @@ public class GenerationTest extends AbstractSymtabTest {
"CNNPredictor_mnist_mnistClassifier_net.h",
"CNNDataLoader_mnist_mnistClassifier_net.py",
"CNNSupervisedTrainer_mnist_mnistClassifier_net.py",
"mnist_mnistClassifier_net.h",
"HelperA.h",
"CNNTranslator.h",
"mnist_mnistClassifier_calculateClass.h",
@@ -300,9 +306,6 @@ public class GenerationTest extends AbstractSymtabTest {
"CNNTrainer_defaultGAN_defaultGANConnector_predictor.py",
"defaultGAN_defaultGANConnector.cpp",
"defaultGAN_defaultGANConnector.h",
"defaultGAN_defaultGANConnector_predictor.h",
"defaultGAN_defaultGANConnector.cpp",
"defaultGAN_defaultGANConnector.h",
"defaultGAN_defaultGANConnector_predictor.h"
)
);
@@ -361,7 +364,7 @@ public class GenerationTest extends AbstractSymtabTest {
EMADLGeneratorCli.main(args);
assertEquals(Log.getFindings().size(), 1);
assertEquals(Log.getFindings().get(0).toString(),
"Tagging info for symbol was found, ignoring data_paths.txt: src/test/resources/models");
"Tagging info for DataPath symbol was found, ignoring data_paths.txt: src/test/resources/models");
assertTrue(Log.getErrorCount() == 0);
}
@@ -70,6 +70,16 @@ public class IntegrationGluonTest extends IntegrationTest {
assertTrue(Log.getFindings().isEmpty());
}
@Test
public void testEpisodicMemorySimple() {
Log.getFindings().clear();
deleteHashFile(Paths.get("./target/generated-sources-emadl/episodicMemorySimple/episodicMemorySimple.training_hash"));
String[] args = {"-m", "src/test/resources/models", "-r", "episodicMemorySimple.Network", "-b", "GLUON"};
EMADLGeneratorCli.main(args);
}
@Test
public void testGluonPreprocessingWithSupervised() {
Log.getFindings().clear();
/* (c) https://github.com/MontiCore/monticore */
configuration Network{
num_epoch:1
batch_size:5
normalize:false
context:cpu
load_checkpoint:false
loss:cross_entropy
optimizer:adam{
learning_rate:0.00003
weight_decay:0.01
}
}
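As a rough, illustrative sketch (ours, not the generated code): the training configuration above maps naturally onto Gluon, with cross_entropy becoming a SoftmaxCrossEntropyLoss and the adam block configuring the Trainer.

from mxnet import gluon

def make_trainer(net):
    # loss: cross_entropy -> SoftmaxCrossEntropyLoss
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    # optimizer: adam with the learning_rate and weight_decay from the configuration
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': 0.00003, 'wd': 0.01})
    return loss_fn, trainer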
/* (c) https://github.com/MontiCore/monticore */
package episodicMemorySimple;
component Network{
ports in Z(0:oo)^{10} data,
out Q(0:1)^{33} softmax;
implementation CNN {
data ->
EpisodicMemory(replayInterval=10, replayBatchSize=100, replaySteps=1, replayGradientSteps=1, replayMemoryStoreProb=0.5, localAdaptionGradientSteps=30, maxStoredSamples=-1, localAdaptionK=32, queryNetDir="tag:simple", queryNetPrefix="simple_embedding-", queryNetNumInputs=1) ->
LoadNetwork(networkDir="tag:simple", networkPrefix="simple_embedding-", numInputs=1, outputShape=(1,768)) ->
FullyConnected(units=33) ->
Softmax() ->
softmax;
}
}
/* (c) https://github.com/MontiCore/monticore */
package episodicMemorySimple;
conforms to dltag.DataPathTagSchema, dltag.LayerPathParameterTagSchema;
tags episodic {
tag Network with DataPath = {path = src/test/resources/training_data/episodicMemorySimple, type = HDF5};
tag Network with LayerPathParameter = {path = src/test/resources/pretrained/episodicMemorySimple, id = simple};
}
{
"nodes": [
{
"op": "null",
"name": "data",
"inputs": []
},
{
"op": "_copy",
"name": "simpleembedding0_identity0",
"inputs": [[0, 0, 0]]
}
],
"arg_nodes": [0],
"node_row_ptr": [0, 1, 2],
"heads": [[1, 0, 0]],
"attrs": {"mxnet_version": ["int", 10501]}
}
\ No newline at end of file
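The symbol graph above is the pretrained query-network stub used by the test (a single _copy of the data input). A minimal sketch of loading such a checkpoint, assuming it is saved as simple_embedding-symbol.json under the pretrained path tagged above (the exact file name is our assumption):

import mxnet as mx
from mxnet import gluon

net = gluon.nn.SymbolBlock.imports('simple_embedding-symbol.json', ['data'])
out = net(mx.nd.zeros((1, 10)))  # 'data' matches the single input node declared above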
@@ -2,6 +2,8 @@ import mxnet as mx
import logging
import os
import shutil
import warnings
import inspect
from CNNNet_mnist_mnistClassifier_net import Net_0
@@ -20,6 +22,10 @@ class CNNCreator_mnist_mnistClassifier_net:
for i, network in self.networks.items():
lastEpoch = 0
param_file = None
if hasattr(network, 'episodic_sub_nets'):
num_episodic_sub_nets = len(network.episodic_sub_nets)
lastMemEpoch = [0]*num_episodic_sub_nets
mem_files = [None]*num_episodic_sub_nets
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
@@ -30,22 +36,77 @@ class CNNCreator_mnist_mnistClassifier_net:
except OSError:
pass
if hasattr(network, 'episodic_sub_nets'):
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")
except OSError:
pass
for j in range(len(network.episodic_sub_nets)):
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000")
except OSError:
pass
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
memSubNet = int(relMemPathInfo[0])
memEpochStr = relMemPathInfo[1]
memEpoch = int(memEpochStr)
if memEpoch >= lastMemEpoch[memSubNet-1]:
lastMemEpoch[memSubNet-1] = memEpoch
mem_files[memSubNet-1] = file
if param_file is None:
earliestLastEpoch = 0
else:
logging.info("Loading checkpoint: " + param_file)
network.load_parameters(self._model_dir_ + param_file)
if hasattr(network, 'episodic_sub_nets'):
for j, sub_net in enumerate(network.episodic_sub_nets):
if mem_files[j] is not None:
logging.info("Loading Replay Memory: " + mem_files[j])
mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
mem_layer.load_memory(self._model_dir_ + mem_files[j])
- if earliestLastEpoch == None or lastEpoch < earliestLastEpoch:
-     earliestLastEpoch = lastEpoch
+ if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch:
+     earliestLastEpoch = lastEpoch + 1
return earliestLastEpoch
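A minimal illustration of the checkpoint naming convention the scan above relies on (the file name below is a hypothetical example):

prefix = "model_0"  # plays the role of self._model_prefix_ + "_" + str(i)
file = "model_0_episodic_memory_sub_net_2-0005"
rel_mem_path_info = file.replace(prefix + "_episodic_memory_sub_net_", "").split("-")
mem_sub_net = int(rel_mem_path_info[0])  # -> 2, the 1-based sub-net index
mem_epoch = int(rel_mem_path_info[1])    # -> 5, the epoch of that memory snapshot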
@@ -56,27 +117,52 @@ class CNNCreator_mnist_mnistClassifier_net:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if hasattr(network, 'episodic_sub_nets'):
num_episodic_sub_nets = len(network.episodic_sub_nets)
lastMemEpoch = [0] * num_episodic_sub_nets
mem_files = [None] * num_episodic_sub_nets
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
memSubNet = int(relMemPathInfo[0])
memEpochStr = relMemPathInfo[1]
memEpoch = int(memEpochStr)
if memEpoch >= lastMemEpoch[memSubNet-1]:
lastMemEpoch[memSubNet-1] = memEpoch
mem_files[memSubNet-1] = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
if hasattr(network, 'episodic_sub_nets'):
# every episodic memory checkpoint must match the epoch of the loaded weights
assert all(lastEpoch == mem_epoch for mem_epoch in lastMemEpoch)
for j, sub_net in enumerate(network.episodic_sub_nets):
if mem_files[j] is not None:
logging.info("Loading pretrained Replay Memory: " + mem_files[j])
mem_layer = \
[param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if
param[0].startswith("memory")][0][1]
mem_layer.load_memory(self._weights_dir_ + mem_files[j])
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
- self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
- self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
+ self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
+ with warnings.catch_warnings():
+     warnings.simplefilter("ignore")
+     self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
self.networks[0].hybridize()
- self.networks[0](mx.nd.zeros((1, 1,28,28,), ctx=context))
+ self.networks[0](mx.nd.zeros((1, 1,28,28,), ctx=context[0]))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
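Usage sketch (assuming a no-argument constructor, which this diff does not show): since the dummy forward pass now indexes context[0], construct expects a list of contexts rather than a single context.

import mxnet as mx

creator = CNNCreator_mnist_mnistClassifier_net()
creator.construct(context=[mx.cpu()])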
import mxnet as mx
import numpy as np
import math
- from mxnet import gluon
+ import os
+ import abc
+ import warnings
+ from mxnet import gluon, nd
class ZScoreNormalization(gluon.HybridBlock):
@@ -86,9 +89,419 @@ class CustomGRU(gluon.HybridBlock):
output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)])
return output, F.swapaxes(state0, 0, 1)
class DotProductSelfAttention(gluon.HybridBlock):
def __init__(self,
scale_factor,
num_heads,
dim_model,
dim_keys,
dim_values,
use_proj_bias,
use_mask,
**kwargs):
super(DotProductSelfAttention, self).__init__(**kwargs)
with self.name_scope():
self.num_heads = num_heads
self.dim_model = dim_model
self.use_proj_bias = use_proj_bias
self.use_mask = use_mask
if dim_keys == -1:
self.dim_keys = int(dim_model / self.num_heads)
else:
self.dim_keys = dim_keys
if dim_values == -1:
self.dim_values = int(dim_model / self.num_heads)
else:
self.dim_values = dim_values
if scale_factor == -1:
# standard scaled dot-product attention scales scores by 1/sqrt(d_k)
self.scale_factor = 1.0 / math.sqrt(self.dim_keys)
else:
self.scale_factor = scale_factor
self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False)
self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False)
def hybrid_forward(self, F, queries, keys, values, *args, **kwargs):
queries = F.Reshape(queries, shape=(0, 0, -1))
keys = F.Reshape(keys, shape=(0, 0, -1))
values = F.Reshape(values, shape=(0, 0, -1))
head_queries = self.proj_q(queries)
head_keys = self.proj_k(keys)
head_values = self.proj_v(values)
head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1))
head_queries = F.transpose(head_queries, axes=(0,2,1,3))
head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True)
head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1))
head_keys = F.transpose(head_keys, axes=(0,2,1,3))
head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True)
score = F.batch_dot(head_queries, head_keys, transpose_b=True)
score = score * self.scale_factor
if self.use_mask:
    # the length mask is expected as an additional forward argument
    mask = args[0]
    mask = F.tile(mask, self.num_heads)
    mask = F.repeat(mask, self.dim_model)
    mask = F.reshape(mask, shape=(-1, self.dim_model))
    weights = F.softmax(score, mask, use_length=True)
else:
    weights = F.softmax(score)
head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1))
head_values = F.transpose(head_values, axes=(0,2,1,3))
head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True)
ret = F.batch_dot(weights, head_values)
ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True)
ret = F.transpose(ret, axes=(0, 2, 1, 3))
ret = F.reshape(ret, shape=(0, 0, -1))
ret = self.proj_o(ret)
return ret
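Usage sketch with assumed shapes: a batch of 5 sequences of length 12 with model dimension 64 and 4 heads; passing -1 lets the block derive dim_keys, dim_values and the scale factor from dim_model.

import mxnet as mx

attn = DotProductSelfAttention(scale_factor=-1, num_heads=4, dim_model=64,
                               dim_keys=-1, dim_values=-1,
                               use_proj_bias=True, use_mask=False)
attn.initialize()
x = mx.nd.random.normal(shape=(5, 12, 64))
y = attn(x, x, x)  # queries, keys, values -> output of shape (5, 12, 64)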
class EpisodicReplayMemoryInterface(gluon.HybridBlock):
__metaclass__ = abc.ABCMeta
def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)
self.use_replay = use_replay
self.replay_interval = replay_interval
self.replay_batch_size = replay_batch_size
self.replay_steps = replay_steps
self.replay_gradient_steps = replay_gradient_steps
self.num_heads = num_heads
@abc.abstractmethod
def store_samples(self, data, y, query_network, store_prob, mx_context):
pass
@abc.abstractmethod
def sample_memory(self, batch_size, mx_context):
pass
@abc.abstractmethod
def get_query_network(self, mx_context):
pass
@abc.abstractmethod
def save_memory(self, path):
pass
@abc.abstractmethod
def load_memory(self, path):
pass
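A minimal sketch of a concrete block satisfying this interface (illustrative only; the generated replay memory is considerably more involved):

class SimpleReplayMemory(EpisodicReplayMemoryInterface):
    def __init__(self, **kwargs):
        super(SimpleReplayMemory, self).__init__(
            use_replay=True, replay_interval=10, replay_batch_size=100,
            replay_steps=1, replay_gradient_steps=1, num_heads=1, **kwargs)
        self.samples = []

    def store_samples(self, data, y, query_network, store_prob, mx_context):
        self.samples.append((data, y))

    def sample_memory(self, batch_size, mx_context):
        return self.samples[-batch_size:]

    def get_query_network(self, mx_context):
        return None  # a real implementation returns the loaded query network

    def save_memory(self, path):
        pass  # a real implementation serializes the stored samples

    def load_memory(self, path):
        pass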
#Memory layer
class LargeMemory(gluon.HybridBlock):
def __init__(self,
sub_key_size,
query_size,
query_act,
dist_measure,
k,
num_heads,
values_dim,
**kwargs):
super(LargeMemory, self).__init__(**kwargs)
with self.name_scope():
#Memory parameters
self.dist_measure = dist_measure
self.k = k
self.num_heads = num_heads
self.query_act = query_act
self.query_size = query_size
#Batch norm sub-layer
self.batch_norm = gluon.nn.BatchNorm()
#Memory sub-layer
self.sub_key_size = sub_key_size
sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2))
if values_dim == -1:
values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1])
else:
values_shape = (self.sub_key_size*self.sub_key_size, values_dim)
self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True)
self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True)
self.values = self.params.get("values", shape=values_shape, differentiable=True)
self.label_memory = nd.array([])
self.get_query_network()
def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values):
x = self.batch_norm(x)
x = F.reshape(x, shape=(0, -1))
q = self.query_network(x)
q = F.reshape(q, shape=(0, self.num_heads, -1))
q_split = F.split(q, num_outputs=2, axis=-1)
if self.dist_measure == "l2":
q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1))
sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True)
q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh)
q1_dist = F.norm(q1_diff, axis=-1)
q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1))
sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True)
q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh)
q2_dist = F.norm(q2_diff, axis=-1)
else:
q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
if self.num_heads == 1:
q1 = [q1]
q2 = [q2]
sub_keys1_resh = [sub_keys1_resh]
sub_keys2_resh = [sub_keys2_resh]
q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
for h in range(1, self.num_heads):
q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1)
q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], transpose_b=True), dim=1)
i1 = F.topk(q1_dist, k=self.k, ret_typ="indices")
i2 = F.topk(q2_dist, k=self.k, ret_typ="indices")
# Calculate cross product for keys at indices I1 and I2
# def head_take(data, state):
# return [F.take(data[0], data[2]), F.take(data[1], data[3])], state,
#
# i1 = F.transpose(i1, axes=(1,0,2))
# i2 = F.transpose(i2, axes=(1, 0, 2))
# st = F.zeros(1)
# (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st)
# k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True)
# k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True)
i1 = F.split(i1, num_outputs=self.num_heads, axis=1)
i2 = F.split(i2, num_outputs=self.num_heads, axis=1)
sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
if self.num_heads == 1:
i1 = [i1]
i2 = [i2]
sub_keys1 = [sub_keys1]
sub_keys2 = [sub_keys2]
k1 = F.take(sub_keys1[0], i1[0])
k2 = F.take(sub_keys2[0], i2[0])
for h in range(1, self.num_heads):
k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1)
k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1)
k1 = F.tile(k1, (1, 1, self.k, 1))
k2 = F.repeat(k2, self.k, 2)
c_cart = F.concat(k1, k2, dim=3)
q = F.reshape(q, shape=(-1,0), reverse=True)
q = F.reshape(q, shape=(0, 1, -1))
c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True)