Commit cfd3bbd8 authored by Julian Treiber's avatar Julian Treiber

added weightsPath

parent f2e2fb86
......@@ -40,7 +40,7 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
supportedLayerList.add(AllPredefinedLayers.REDUCE_SUM_NAME);
supportedLayerList.add(AllPredefinedLayers.BROADCAST_ADD_NAME);
supportedLayerList.add(AllPredefinedLayers.RESHAPE_NAME);
supportedLayerList.add(AllPredefinedLayers.CROP_NAME);
// supportedLayerList.add(AllPredefinedLayers.CROP_NAME);
}
}
......@@ -2,6 +2,7 @@
import mxnet as mx
import logging
import os
import shutil
<#list tc.architecture.networkInstructions as networkInstruction>
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index}
......@@ -14,8 +15,13 @@ class ${tc.fileNameWithoutEnding}:
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.networks = {}
<#if (tc.weightsPath)??>
self._weights_dir_ = "${tc.weightsPath}/"
<#else>
self._weights_dir_ = None
</#if>
def load(self, context, load_pretrained=False, pretrained_files=None):
def load(self, context):
earliestLastEpoch = None
for i, network in self.networks.items():
......@@ -30,7 +36,6 @@ class ${tc.fileNameWithoutEnding}:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-symbol.json")
except OSError:
pass
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
......@@ -50,6 +55,29 @@ class ${tc.fileNameWithoutEnding}:
return earliestLastEpoch
def load_pretrained_weights(self, context):
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
if self._weights_dir_ is not None:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file)
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
<#list tc.architecture.networkInstructions as networkInstruction>
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std)
......
......@@ -266,7 +266,9 @@ class ${tc.fileNameWithoutEnding}:
begin_epoch = 0
if load_checkpoint:
begin_epoch = self._net_creator.load(mx_context, load_pretrained=load_pretrained)
begin_epoch = self._net_creator.load(mx_context)
elif load_pretrained:
self._net_creator.load_pretrained_weights(mx_context)
else:
if os.path.isdir(self._net_creator._model_dir_):
shutil.rmtree(self._net_creator._model_dir_)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment