CNNCreator_RosCriticNetwork.py 9.46 KB
Newer Older
Nicola Gatto's avatar
Nicola Gatto committed
1
2
3
import mxnet as mx
import logging
import os
Julian Treiber's avatar
Julian Treiber committed
4
import shutil
5
import warnings
6
import inspect
Christian Fuß's avatar
Christian Fuß committed
7

Nicola Gatto's avatar
Nicola Gatto committed
8
from CNNNet_RosCriticNetwork import Net_0
Nicola Gatto's avatar
Nicola Gatto committed
9
10
11
12
13
14
15

class CNNCreator_RosCriticNetwork:
    """Construct, export and restore the Gluon network(s) of RosCriticNetwork.

    Checkpoints are looked up in ``_model_dir_`` as files named
    ``<_model_prefix_>_<network index>-<epoch>.params``. Pretrained weights use
    the same naming scheme inside ``_weights_dir_``.
    """

    _model_dir_ = "model/RosCriticNetwork/"
    _model_prefix_ = "model"

    def __init__(self):
        self.weight_initializer = mx.init.Normal()
        # Network index -> network object; filled by construct().
        self.networks = {}
        # Directory holding pretrained checkpoints; None disables them.
        self._weights_dir_ = None

    @staticmethod
    def _remove_if_exists(path):
        """Delete *path*, ignoring the OSError raised when it cannot be removed."""
        try:
            os.remove(path)
        except OSError:
            pass

    @staticmethod
    def _parse_epoch(file_name, prefix, suffix):
        """Return the epoch of a ``<prefix><digits><suffix>`` file name, else None."""
        if not (file_name.startswith(prefix) and file_name.endswith(suffix)):
            return None
        epoch_str = file_name[len(prefix):len(file_name) - len(suffix)]
        # isdecimal (not isdigit) guarantees int() accepts the string.
        return int(epoch_str) if epoch_str.isdecimal() else None

    @staticmethod
    def _find_memory_layer(sub_net):
        """Return the first non-routine attribute of *sub_net* named ``memory*``.

        Raises IndexError when *sub_net* has no such attribute.
        """
        members = inspect.getmembers(sub_net, lambda x: not inspect.isroutine(x))
        return [value for name, value in members if name.startswith("memory")][0]

    def _remove_newest_files(self, i, network):
        """Delete the ``*_newest*`` export files of network *i* in ``_model_dir_``."""
        base = self._model_dir_ + self._model_prefix_ + "_" + str(i)
        self._remove_if_exists(base + "_newest-0000.params")
        self._remove_if_exists(base + "_newest-symbol.json")
        if not hasattr(network, 'episodic_sub_nets'):
            return
        self._remove_if_exists(base + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
        self._remove_if_exists(base + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")
        # The loss export does not depend on the sub-net index, so remove it once.
        self._remove_if_exists(base + '_newest_loss' + "-0000.params")
        self._remove_if_exists(base + '_newest_loss' + "-symbol.json")
        for j in range(1, len(network.episodic_sub_nets) + 1):
            for kind in ('_newest_episodic_sub_net_', '_newest_episodic_query_net_'):
                self._remove_if_exists(base + kind + str(j) + "-0000.params")
                self._remove_if_exists(base + kind + str(j) + "-symbol.json")
            self._remove_if_exists(base + "_newest_episodic_memory_sub_net_" + str(j) + "-0000")

    def _scan_checkpoints(self, directory, i, network):
        """Find the newest parameter file and replay-memory files of network *i*.

        :param directory: directory to scan (file names are joined to it by
            plain concatenation by the callers).
        :return: ``(last_epoch, param_file, last_mem_epochs, mem_files)``.
            ``param_file`` is None when no ``<prefix>_<i>-<epoch>.params`` file
            exists. The last two entries are None unless *network* has
            ``episodic_sub_nets``; otherwise they are lists with one slot per
            sub-net (epoch 0 / None when no memory file was found).
        """
        net_prefix = self._model_prefix_ + "_" + str(i)
        param_prefix = net_prefix + "-"
        mem_prefix = net_prefix + "_episodic_memory_sub_net_"
        episodic = hasattr(network, 'episodic_sub_nets')

        last_epoch = 0
        param_file = None
        last_mem_epochs = None
        mem_files = None
        if episodic:
            num_sub_nets = len(network.episodic_sub_nets)
            last_mem_epochs = [0] * num_sub_nets
            mem_files = [None] * num_sub_nets

        if not os.path.isdir(directory):
            return last_epoch, param_file, last_mem_epochs, mem_files

        for file_name in os.listdir(directory):
            # Exact match only: sub-net/query-net/loss/newest params and
            # other network indices (model_1 vs model_10) are not checkpoints.
            epoch = self._parse_epoch(file_name, param_prefix, ".params")
            if epoch is not None:
                if epoch >= last_epoch:
                    last_epoch = epoch
                    param_file = file_name
            elif episodic and file_name.startswith(mem_prefix):
                # Expected form: <mem_prefix><sub-net number, 1-based>-<epoch>
                sub_str, _, epoch_str = file_name[len(mem_prefix):].partition("-")
                if not (sub_str.isdecimal() and epoch_str.isdecimal()):
                    continue
                sub_idx = int(sub_str) - 1
                if not 0 <= sub_idx < len(mem_files):
                    continue
                mem_epoch = int(epoch_str)
                if mem_epoch >= last_mem_epochs[sub_idx]:
                    last_mem_epochs[sub_idx] = mem_epoch
                    mem_files[sub_idx] = file_name

        return last_epoch, param_file, last_mem_epochs, mem_files

    def _load_memories(self, network, mem_files, directory, message):
        """Load each found replay-memory file into the matching episodic sub-net."""
        for sub_net, mem_file in zip(network.episodic_sub_nets, mem_files):
            if mem_file is not None:
                logging.info(message + mem_file)
                self._find_memory_layer(sub_net).load_memory(directory + mem_file)

    def load(self, context):
        """Restore every network from its newest checkpoint in ``_model_dir_``.

        For each network, leftover ``*_newest*`` files are deleted first. Then the
        newest checkpoint is loaded and, for networks with
        ``episodic_sub_nets``, the newest replay memory of each sub-net.

        :param context: unused; kept for interface compatibility.
        :return: the epoch to resume from: the smallest ``last epoch + 1`` over
            all networks; 0 if any network has no checkpoint; None if there are
            no networks.
        """
        earliestLastEpoch = None

        for i, network in self.networks.items():
            self._remove_newest_files(i, network)
            lastEpoch, param_file, _, mem_files = self._scan_checkpoints(self._model_dir_, i, network)

            if param_file is None:
                earliestLastEpoch = 0
                continue

            logging.info("Loading checkpoint: " + param_file)
            network.load_parameters(self._model_dir_ + param_file)
            if mem_files is not None:
                self._load_memories(network, mem_files, self._model_dir_, "Loading Replay Memory: ")

            if earliestLastEpoch is None or lastEpoch + 1 < earliestLastEpoch:
                earliestLastEpoch = lastEpoch + 1

        return earliestLastEpoch

    def load_pretrained_weights(self, context):
        """Initialise the networks from the newest checkpoints in ``_weights_dir_``.

        ``_model_dir_`` is deleted (if it exists) before anything else, even when
        ``_weights_dir_`` is None. Missing parameters are allowed and extra ones
        ignored. A network without a matching pretrained checkpoint is skipped
        with an info log.

        :param context: unused; kept for interface compatibility.
        :raises ValueError: if a found replay-memory file's epoch differs from the
            epoch of the pretrained parameter file.
        """
        if os.path.isdir(self._model_dir_):
            shutil.rmtree(self._model_dir_)
        if self._weights_dir_ is None:
            return

        for i, network in self.networks.items():
            if not os.path.isdir(self._weights_dir_):
                logging.info("No pretrained weights available at: " + self._weights_dir_)
                continue

            lastEpoch, param_file, lastMemEpoch, mem_files = self._scan_checkpoints(
                self._weights_dir_, i, network)
            if param_file is None:
                logging.info("No pretrained weights available at: " + self._weights_dir_)
                continue

            logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
            network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)

            if mem_files is not None:
                # Only memories that were actually found must match the weights epoch.
                mismatched = [epoch for epoch, mem_file in zip(lastMemEpoch, mem_files)
                              if mem_file is not None and epoch != lastEpoch]
                if mismatched:
                    raise ValueError(
                        "Replay memory epochs %s do not match pretrained weights epoch %d in %s"
                        % (mismatched, lastEpoch, self._weights_dir_))
                # Memories live next to the pretrained weights, not in _model_dir_.
                self._load_memories(network, mem_files, self._weights_dir_,
                                    "Loading pretrained Replay Memory: ")

    def construct(self, context, data_mean=None, data_std=None):
        """Create, initialise, hybridize and export network 0.

        :param context: list of MXNet contexts; passed as ``ctx`` to
            ``initialize`` and its first entry is used for the dummy forward pass.
        :param data_mean: forwarded to ``Net_0``.
        :param data_std: forwarded to ``Net_0``.
        """
        self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
        self.networks[0].hybridize()
        # One forward pass on zero inputs shaped like getInputs() (batch size 1).
        self.networks[0](mx.nd.zeros((1, 8,), ctx=context[0]), mx.nd.zeros((1, 3,), ctx=context[0]))

        if not os.path.exists(self._model_dir_):
            os.makedirs(self._model_dir_)

        for i, network in self.networks.items():
            network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)

    def setWeightInitializer(self, initializer):
        """Set the initializer used by construct()."""
        self.weight_initializer = initializer

    def getInputs(self):
        """Return ``{input name: (type, min, max, dimensions)}`` for the network inputs."""
        inputs = {}
        input_dimensions = (8,)
        input_domains = (float,float('-inf'),float('inf'),)
        inputs["state_"] = input_domains + (input_dimensions,)
        input_dimensions = (3,)
        input_domains = (float,-1.0,1.0,)
        inputs["action_"] = input_domains + (input_dimensions,)
        return inputs

    def getOutputs(self):
        """Return ``{output name: (type, min, max, dimensions)}`` for the network outputs."""
        outputs = {}
        output_dimensions = (1,1,1,)
        output_domains = (float,float('-inf'),float('inf'),)
        outputs["qvalues_"] = output_domains + (output_dimensions,)
        return outputs