<#-- (c) https://github.com/MontiCore/monticore -->
import mxnet as mx
import numpy as np
import math
import os
import abc
import warnings
import sys
from mxnet import gluon, nd
<#if tc.containsAdaNet()>
from mxnet.gluon import nn, HybridBlock
from numpy import log, product, prod
from mxnet.ndarray import zeros, zeros_like
</#if>
<#if tc.architecture.customPyFilesPath??>
sys.path.insert(1, '${tc.architecture.customPyFilesPath}')
from custom_layers import *
</#if>

class ZScoreNormalization(gluon.HybridBlock):
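    """Normalizes the input with constant (non-trainable) mean and std parameters: x -> (x - data_mean) / data_std."""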
    def __init__(self, data_mean, data_std, **kwargs):
        super(ZScoreNormalization, self).__init__(**kwargs)
        with self.name_scope():
            self.data_mean = self.params.get('data_mean', shape=data_mean.shape,
                init=mx.init.Constant(data_mean.asnumpy().tolist()), differentiable=False)
            self.data_std = self.params.get('data_std', shape=data_std.shape,
                init=mx.init.Constant(data_std.asnumpy().tolist()), differentiable=False)

    def hybrid_forward(self, F, x, data_mean, data_std):
        x = F.broadcast_sub(x, data_mean)
        x = F.broadcast_div(x, data_std)
        return x


class Padding(gluon.HybridBlock):
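    """Zero-pads the input according to the given pad_width specification of F.pad."""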
    def __init__(self, padding, **kwargs):
        super(Padding, self).__init__(**kwargs)
        with self.name_scope():
            self.pad_width = padding

    def hybrid_forward(self, F, x):
        x = F.pad(data=x,
            mode='constant',
            pad_width=self.pad_width,
            constant_value=0)
        return x


class NoNormalization(gluon.HybridBlock):
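    """Identity block, used when no input normalization is requested."""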
    def __init__(self, **kwargs):
        super(NoNormalization, self).__init__(**kwargs)

    def hybrid_forward(self, F, x):
        return x


class Reshape(gluon.HybridBlock):
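    """Reshapes the input to the configured target shape."""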
    def __init__(self, shape, **kwargs):
        super(Reshape, self).__init__(**kwargs)
        with self.name_scope():
            self.shape = shape

    def hybrid_forward(self, F, x):
        return F.reshape(data=x, shape=self.shape)


class CustomRNN(gluon.HybridBlock):
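    """Tanh RNN in NTC layout; the hidden state is exchanged batch-major and swapped to layer-major for the wrapped gluon RNN."""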
    def __init__(self, hidden_size, num_layers, dropout, bidirectional, **kwargs):
        super(CustomRNN, self).__init__(**kwargs)
        with self.name_scope():
            self.rnn = gluon.rnn.RNN(hidden_size=hidden_size, num_layers=num_layers, dropout=dropout,
                                     bidirectional=bidirectional, activation='tanh', layout='NTC')

    def hybrid_forward(self, F, data, state0):
        output, [state0] = self.rnn(data, [F.swapaxes(state0, 0, 1)])
        return output, F.swapaxes(state0, 0, 1)


class CustomLSTM(gluon.HybridBlock):
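    """LSTM in NTC layout; hidden and cell states are exchanged batch-major and swapped to layer-major for the wrapped gluon LSTM."""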
    def __init__(self, hidden_size, num_layers, dropout, bidirectional, **kwargs):
        super(CustomLSTM, self).__init__(**kwargs)
        with self.name_scope():
            self.lstm = gluon.rnn.LSTM(hidden_size=hidden_size, num_layers=num_layers, dropout=dropout,
                                       bidirectional=bidirectional, layout='NTC')

    def hybrid_forward(self, F, data, state0, state1):
        output, [state0, state1] = self.lstm(data, [F.swapaxes(state0, 0, 1), F.swapaxes(state1, 0, 1)])
        return output, F.swapaxes(state0, 0, 1), F.swapaxes(state1, 0, 1)


class CustomGRU(gluon.HybridBlock):
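    """GRU in NTC layout; the hidden state is exchanged batch-major and swapped to layer-major for the wrapped gluon GRU."""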
    def __init__(self, hidden_size, num_layers, dropout, bidirectional, **kwargs):
        super(CustomGRU, self).__init__(**kwargs)
        with self.name_scope():
            self.gru = gluon.rnn.GRU(hidden_size=hidden_size, num_layers=num_layers, dropout=dropout,
                                     bidirectional=bidirectional, layout='NTC')

    def hybrid_forward(self, F, data, state0):
        output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)])
        return output, F.swapaxes(state0, 0, 1)

class DotProductSelfAttention(gluon.HybridBlock):
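    """
    Multi-head scaled dot-product self-attention with learned projections for
    queries, keys, values, and the output. If use_mask is set, an additional
    valid-length input excludes padded positions from the softmax.
    """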
    def __init__(self,
                 scale_factor,
                 num_heads,
                 dim_model,
                 dim_keys,
                 dim_values,
                 use_proj_bias,
                 use_mask,
                 **kwargs):
        super(DotProductSelfAttention, self).__init__(**kwargs)
        with self.name_scope():
            self.num_heads = num_heads
            self.dim_model = dim_model
            self.use_proj_bias = use_proj_bias
            self.use_mask = use_mask

            if dim_keys == -1:
                self.dim_keys = int(dim_model / self.num_heads)
            else:
                self.dim_keys = dim_keys
            if dim_values == -1:
                self.dim_values = int(dim_model / self.num_heads)
            else:
                self.dim_values = dim_values
    
            if scale_factor == -1:
                # standard scaled dot-product attention divides the scores by sqrt(d_k)
                self.scale_factor = 1.0 / math.sqrt(self.dim_keys)
            else:
                self.scale_factor = scale_factor

            self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
            self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
            self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False)
            self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False)

    def hybrid_forward(self, F, queries, keys, values, *args, **kwargs):

        queries = F.reshape(queries, shape=(0, 0, -1))
        keys = F.reshape(keys, shape=(0, 0, -1))
        values = F.reshape(values, shape=(0, 0, -1))
    
        head_queries = self.proj_q(queries)
        head_keys = self.proj_k(keys)
        head_values = self.proj_v(values)

        head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1))
        head_queries = F.transpose(head_queries, axes=(0,2,1,3))
        head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True)

        head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1))
        head_keys = F.transpose(head_keys, axes=(0,2,1,3))
        head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True)

        score = F.batch_dot(head_queries, head_keys, transpose_b=True)
        score = score * self.scale_factor

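        # args[0] holds the valid sequence lengths; broadcast them per head and let the softmax ignore padded positions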
        if self.use_mask:
            seqs = F.contrib.arange_like(score, axis=1)
            zeros = F.zeros_like(seqs)
            zeros = F.reshape(zeros, shape=(1, -1))
            mask = args[0]
            mask = F.reshape(mask, shape=(-1, 1))
            mask = F.broadcast_add(mask, zeros)
            mask = F.expand_dims(mask, axis=1)
            mask = F.broadcast_axis(mask, axis=1, size=self.num_heads)
            mask = mask.reshape(shape=(-1, 0), reverse=True)
            mask = F.cast(mask, dtype='int32')
            weights = F.softmax(score, mask, use_length=self.use_mask)
        else:
            weights = F.softmax(score)

        head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1))
        head_values = F.transpose(head_values, axes=(0,2,1,3))
        head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True)

        ret = F.batch_dot(weights, head_values)
        ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True)
        ret = F.transpose(ret, axes=(0, 2, 1, 3))
        ret = F.reshape(ret, shape=(0, 0, -1))

        ret = self.proj_o(ret)

        return ret

    
class EpisodicReplayMemoryInterface(gluon.HybridBlock, metaclass=abc.ABCMeta):
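    """
    Interface for layers with an episodic replay memory. Concrete subclasses
    implement how samples are stored, replayed, and (de)serialized.
    """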

    def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, use_local_adaptation, local_adaptation_gradient_steps, k, **kwargs):
        super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)

        self.use_replay = use_replay
        self.replay_interval = replay_interval
        self.replay_batch_size = replay_batch_size
        self.replay_steps = replay_steps
        self.replay_gradient_steps = replay_gradient_steps

        self.use_local_adaptation = use_local_adaptation
        self.local_adaptation_gradient_steps = local_adaptation_gradient_steps
        self.k = k

    @abc.abstractmethod
    def store_samples(self, data, y, query_network, store_prob, mx_context):
        pass

    @abc.abstractmethod
    def sample_memory(self, batch_size, mx_context):
        pass

    @abc.abstractmethod
    def sample_neighbours(self, data, query_network):
        pass

    @abc.abstractmethod
    def get_query_network(self, mx_context):
        pass

    @abc.abstractmethod
    def save_memory(self, path):
        pass

    @abc.abstractmethod
    def load_memory(self, path):
        pass


# Memory layer
class LargeMemory(gluon.HybridBlock):
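    """
    Product-key memory layer: the query is split into two halves, each half is
    scored against its own set of sub-keys, and the top-k index combinations of
    both halves select the value slots to read (in the style of "Large Memory
    Layers with Product Keys", Lample et al., 2019).
    """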
    def __init__(self, 
                 sub_key_size, 
                 query_size, 
                 query_act,
                 k, 
                 num_heads,
                 values_dim,
                 **kwargs):
        super(LargeMemory, self).__init__(**kwargs)
        with self.name_scope():
            # Memory parameters
            self.k = k
            self.num_heads = num_heads
            self.query_act = query_act
            self.query_size = query_size
            # Batch norm sub-layer
            self.batch_norm = gluon.nn.BatchNorm()

            # Memory sub-layer
            self.sub_key_size = sub_key_size
            sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2))

            if values_dim == -1:
                values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1])
            else:
                values_shape = (self.sub_key_size * self.sub_key_size, values_dim)

            self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True)
            self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True)
            self.values = self.params.get("values", shape=values_shape, differentiable=True)
            self.label_memory = nd.array([])

            self.get_query_network()
    def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values):
        x = self.batch_norm(x)

        x = F.reshape(x, shape=(0, -1))

        q = self.query_network(x)

        q = F.reshape(q, shape=(0, self.num_heads, -1))

        q_split = F.split(q, num_outputs=2, axis=-1)

        q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
        q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
        sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
        sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
        if self.num_heads == 1:
            q1 = [q1]
            q2 = [q2]
            sub_keys1_resh = [sub_keys1_resh]
            sub_keys2_resh = [sub_keys2_resh]

        q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
        q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
        for h in range(1, self.num_heads):
            q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1)
            q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], transpose_b=True), dim=1)

        i1 = F.topk(q1_dist, k=self.k, ret_typ="indices")
        i2 = F.topk(q2_dist, k=self.k, ret_typ="indices")

        i1 = F.split(i1, num_outputs=self.num_heads, axis=1)
        i2 = F.split(i2, num_outputs=self.num_heads, axis=1)
        sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
        sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
        if self.num_heads == 1:
            i1 = [i1]
            i2 = [i2]
            sub_keys1 = [sub_keys1]
            sub_keys2 = [sub_keys2]

        k1 = F.take(sub_keys1[0], i1[0])
        k2 = F.take(sub_keys2[0], i2[0])
        for h in range(1, self.num_heads):
            k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1)
            k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1)

        k1 = F.tile(k1, (1, 1, self.k, 1))
        k2 = F.repeat(k2, self.k, 2)
        c_cart = F.concat(k1, k2, dim=3)

        q = F.reshape(q, shape=(-1, 0), reverse=True)
        q = F.reshape(q, shape=(0, 1, -1))
        c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True)

        k_dist = F.batch_dot(q, c_cart, transpose_b=True)
        k_dist = F.reshape(k_dist, shape=(0, -1))

        i = F.topk(k_dist, k=self.k, ret_typ="both")

        w = F.softmax(i[0])
        w = F.reshape(w, shape=(0, 1, -1))
        vi = F.take(values, i[1])
        aggr_value = F.batch_dot(w, vi)

        ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True)
        one_vec = F.ones((1, 1, self.num_heads))
        one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0)
        ret = F.batch_dot(one_vec, ret)
        ret = F.reshape(ret, shape=(-1, 0), reverse=True)

        return ret

    def get_query_network(self):
        if hasattr(self, 'query_network'):
            return self.query_network
        else:
            self.query_network = gluon.nn.HybridSequential()
            for size in self.query_size:
                if self.query_act == "linear":
                    self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False))
                else:
                    self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False))
            return self.query_network


# EpisodicMemory layer
class EpisodicMemory(EpisodicReplayMemoryInterface):
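    """
    Episodic memory layer: forwards its input unchanged and, as a side effect
    of training, stores (key, value, label) triples so that earlier samples
    can be replayed or retrieved for local adaptation.
    """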
    def __init__(self,
                 replay_interval,
                 replay_batch_size,
                 replay_steps,
                 replay_gradient_steps,
                 store_prob,
                 max_stored_samples,
                 memory_replacement_strategy,
                 use_replay,
                 use_local_adaptation,
                 local_adaptation_gradient_steps,
                 k,
                 query_net_dir,
                 query_net_prefix,
                 query_net_num_inputs,
                 **kwargs):
        super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, use_local_adaptation, local_adaptation_gradient_steps, k, **kwargs)
        with self.name_scope():
            # Replay parameters
            self.store_prob = store_prob
            self.max_stored_samples = max_stored_samples
            self.memory_replacement_strategy = memory_replacement_strategy

            self.query_net_dir = query_net_dir
            self.query_net_prefix = query_net_prefix
            self.query_net_num_inputs = query_net_num_inputs

            # Memory
            self.key_memory = nd.array([])
            self.value_memory = nd.array([])
            self.label_memory = nd.array([])

    def hybrid_forward(self, F, *args):
        # propagate the input; everything else is only used for replay
        return [args, []]

    def store_samples(self, data, y, query_network, store_prob, context):
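        # Draw a store/skip decision per sample from store_prob, encode the kept
        # inputs with the query network, and append keys/values/labels to the CPU-side memory.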
        if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples):
            num_pus = len(data)
            sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)]
            num_inputs = len(data[0][0])
            num_outputs = len(y)
            mx_context = context[0]

            if len(self.key_memory) == 0:
                self.key_memory = nd.empty(0, ctx=mx.cpu())
                self.value_memory = []
                self.label_memory = []

            ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)]

            max_inds = [nd.max(ind[i]) for i in range(num_pus)]
            if any(max_inds):
                to_store_values = []
                for i in range(num_inputs):
                    tmp_values = []
                    for j in range(0, num_pus):
                        if max_inds[j]:
                            if isinstance(tmp_values, list):
                                tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j])
                            else:
                                tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0)
                    to_store_values.append(tmp_values)

                to_store_labels = []
                for i in range(num_outputs):
                    tmp_labels = []
                    for j in range(0, num_pus):
                        if max_inds[j]:
                            if isinstance(tmp_labels, list):
                                tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j])
                            else:
                                tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0)
                    to_store_labels.append(tmp_labels)

                to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs])

                if self.key_memory.shape[0] == 0:
                    self.key_memory = to_store_keys.as_in_context(mx.cpu())
                    for i in range(num_inputs):
                        self.value_memory.append(to_store_values[i].as_in_context(mx.cpu()))
                    for i in range(num_outputs):
                        self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu()))
                elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples:
                    num_to_store = to_store_keys.shape[0]
                    self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0)
                    for i in range(num_inputs):
                        self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0)
                    for i in range(num_outputs):
                        self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
                else:
                    self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0)
                    for i in range(num_inputs):
                        self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0)
                    for i in range(num_outputs):
                        self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0)

    def sample_memory(self, batch_size):
        num_stored_samples = self.key_memory.shape[0]
        if self.replay_batch_size == -1:
            sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu())
        else:
            sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu())

        num_outputs = len(self.label_memory)

        sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind]
        sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)]

        return sample_batches

    def sample_neighbours(self, data, query_network):
        num_stored_samples = self.key_memory.shape[0]
        batch_size = data[0].shape[0]

        query = query_network(*data).as_in_context(mx.cpu())

        vec1 = nd.repeat(query, repeats=num_stored_samples, axis=0)
        vec2 = nd.tile(self.key_memory, reps=(batch_size, 1))
        diff = nd.subtract(vec1, vec2)
        sq = nd.square(diff)
        batch_sum = nd.sum(sq, exclude=1, axis=0)
        sqrt = nd.sqrt(batch_sum)

        dist = nd.reshape(sqrt, shape=(batch_size, num_stored_samples))

        # nearest neighbours are those with the smallest distances, so sort ascending
        sample_ind = nd.topk(dist, k=self.k, axis=1, ret_typ="indices", is_ascend=True)
        num_outputs = len(self.label_memory)

        sample_labels = [self.label_memory[i][sample_ind] for i in range(num_outputs)]
        sample_batches = [[self.value_memory[j][sample_ind] for j in range(len(self.value_memory))], sample_labels]

        return sample_batches

    def get_query_network(self, context):
        lastEpoch = 0
        for file in os.listdir(self.query_net_dir):
            if self.query_net_prefix in file and ".json" in file:
                symbolFile = file

            if self.query_net_prefix in file and ".params" in file:
                epochStr = file.replace(".params", "").replace(self.query_net_prefix, "")
                epoch = int(epochStr)
                if epoch >= lastEpoch:
                    lastEpoch = epoch
                    weightFile = file

        inputNames = []
        if self.query_net_num_inputs == 1:
            inputNames.append("data")
        else:
            for i in range(self.query_net_num_inputs):
                inputNames.append("data" + str(i))
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0])
        net.hybridize()
        return net
    
    def save_memory(self, path):
        mem_arr = [("keys", self.key_memory)] + [("values_" + str(i), v) for (i, v) in enumerate(self.value_memory)] + [("labels_" + str(i), v) for (i, v) in enumerate(self.label_memory)]
        mem_dict = {entry[0]: entry[1] for entry in mem_arr}
        nd.save(path, mem_dict)

    def load_memory(self, path):
        mem_dict = nd.load(path)
        self.value_memory = []
        self.label_memory = []
        for key in sorted(mem_dict.keys()):
            if key == "keys":
                self.key_memory = mem_dict[key]
            elif key.startswith("values_"):
                self.value_memory.append(mem_dict[key])
            elif key.startswith("labels_"):
                self.label_memory.append(mem_dict[key])
<#if tc.containsAdaNet()>
# Generation of the AdaNet building blocks for the streams below
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.body.containsAdaNet()>
${tc.include(networkInstruction.body, "ADANET_CONSTRUCTION")}
# Originally named "Model"; the class must be renamed consistently in the other generated parts.
class Net_${networkInstruction?index}(gluon.HybridBlock):
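    """
    AdaNet model shell: applies the candidate operations to the input,
    concatenates their outputs, and maps them to the target shape with a
    final Dense layer.
    """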
    def __init__(self, operations: dict, **kwargs):
        super(Net_${networkInstruction?index}, self).__init__(**kwargs)
        self.AdaNet = True
        self.op_names = []

        with self.name_scope():
            if operations is None:
                operations={'dummy':nn.Dense(units = 10)}
            self.data_shape = <#list networkInstruction.body.getAdaLayer().get().outputTypes as type>(${tc.join(type.dimensions, ",")})</#list>
            self.classes = prod(list(self.data_shape))

            for name, operation in operations.items():
                self.__setattr__(name, operation)
                self.op_names.append(name)
            self.out = nn.Dense(units=self.classes, activation=None, flatten=False)

    def hybrid_forward(self, F, x):
        res_list = []
        for name in self.op_names:
            res_list.append(self.__getattribute__(name)(x))
        if not res_list:
            res_list = [F.identity(x)]
        res = tuple(res_list)
        y = F.concat(*res, dim=1)
        y = self.out(y)
        y = F.reshape(y, shape=self.data_shape)
        return y

    def get_candidate_complexity(self):
        mean_complexity = zeros(len(self.op_names))
        for i, name in enumerate(self.op_names):
            mean_complexity[i] = self.candidate_complexities[name]
        return mean_complexity

class DataClass_${networkInstruction?index}:
    """
    The whole model with its operations.
    """
    def __init__(self, **kwargs):
        self.op_names = []  # list that holds the name of the added operations
        self.candidate_complexities = {}
        self.name_ = 'Net_${networkInstruction?index}'
        self.AdaNet = True
        self.Builder = Builder
        self.CandidateHull = CandidateHull
        self.BuildingBlock = BuildingBlock
        self.output_shape = self.CandidateHull(name='getOutputShape', stack=0).output_shape
        self.model_template = Net_${networkInstruction?index}
</#if>
</#list>
<#else>
<#list tc.architecture.networkInstructions as networkInstruction>
# Stream ${networkInstruction?index}
<#list networkInstruction.body.episodicSubNetworks as elements>
class EpisodicSubNet_${elements?index}(gluon.HybridBlock):
<#if elements?index == 0>
    def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs):
        super(EpisodicSubNet_${elements?index}, self).__init__(**kwargs)
        self.AdaNet = False
        with self.name_scope():
${tc.include(networkInstruction.body, elements?index, "ARCHITECTURE_DEFINITION")}
    
            pass
    
    def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ")}):
${tc.include(networkInstruction.body, elements?index, "FORWARD_FUNCTION")}
        return [[${tc.join(tc.getSubnetOutputNames(elements), ", ")}]]
<#else>
    def __init__(self, mx_context=None, **kwargs):
        super(EpisodicSubNet_${elements?index}, self).__init__(**kwargs)
        with self.name_scope():
${tc.include(networkInstruction.body, elements?index, "ARCHITECTURE_DEFINITION")}
    
            pass
    
    def hybrid_forward(self, F, *args):
${tc.include(networkInstruction.body, elements?index, "FORWARD_FUNCTION")}
        retNames = [${tc.join(tc.getSubnetOutputNames(elements), ", ")}]
        ret = []
        for elem in retNames:
            if isinstance(elem, list) and len(elem) >= 2:
                for elem2 in elem:
                    ret.append(elem2)
            else:
                ret.append(elem)
        return [ret, [${tc.getSubnetInputNames(elements)[0]}full_, ind_${tc.join(tc.getSubnetInputNames(elements), ", ")}]]
</#if>

</#list>

class Net_${networkInstruction?index}(gluon.HybridBlock):
    def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs):
        super(Net_${networkInstruction?index}, self).__init__(**kwargs)
        with self.name_scope():
<#if networkInstruction.body.episodicSubNetworks?has_content>
<#list networkInstruction.body.episodicSubNetworks as elements>
<#if elements?index == 0>
            self.episodicsubnet0_ = EpisodicSubNet_${elements?index}(data_mean, data_std, mx_context)

            self.episodic_sub_nets = []

<#else>
            self.episodic_sub_nets.append(EpisodicSubNet_${elements?index}(mx_context=mx_context))
            self.register_child(self.episodic_sub_nets[${elements?index - 1}])

</#if>
</#list>
<#else>
${tc.include(networkInstruction.body, "ARCHITECTURE_DEFINITION")}
</#if>
            pass


    def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ")}):
<#if networkInstruction.body.episodicSubNetworks?has_content>
<#list networkInstruction.body.episodicSubNetworks as elements>
<#if elements?index == 0>
        episodicsubnet${elements?index}_ = self.episodicsubnet${elements?index}_(${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ")})
<#else>
        episodicsubnet${elements?index}_ = self.episodic_sub_nets[${elements?index-1}](*episodicsubnet${elements?index - 1}_[0])
</#if>
</#list>
        return [episodicsubnet${networkInstruction.body.episodicSubNetworks?size - 1}_[0], [<#list networkInstruction.body.episodicSubNetworks as elements><#if elements?index != 0>episodicsubnet${elements?index}_[1], </#if></#list>]]
<#else>
${tc.include(networkInstruction.body, "FORWARD_FUNCTION")}
<#if tc.isAttentionNetwork() && networkInstruction.isUnroll()>
        return [[${tc.join(tc.getStreamOutputNames(networkInstruction.body, false), ", ")}], [attention_output_]]
<#else>
        return [[${tc.join(tc.getStreamOutputNames(networkInstruction.body, false), ", ")}]]
</#if>
</#if>
</#list>
</#if>