Die Migration der Bereiche "Docker Registry" und "Artifiacts" ist fast abgeschlossen. Die letzten Daten werden im Laufe des heutigen Abend (05.08.2021) noch vollständig hochgeladen. Das Anlegen neuer Images und Artifacts funktioniert bereits wieder.

Commit c9f67c8a authored by Julian Dierkes's avatar Julian Dierkes
Browse files

Merge branch 'develop' of...

Merge branch 'develop' of git.rwth-aachen.de:monticore/EmbeddedMontiArc/generators/EMADL2CPP into develop
parents ec0f0303 d5f94b37
Pipeline #268518 failed with stage
in 1 minute and 20 seconds
This diff is collapsed.
This diff is collapsed.
......@@ -17,8 +17,8 @@
<!-- .. SE-Libraries .................................................. -->
<emadl.version>0.2.11-SNAPSHOT</emadl.version>
<CNNTrain.version>0.3.9-SNAPSHOT</CNNTrain.version>
<cnnarch-generator.version>0.0.5-SNAPSHOT</cnnarch-generator.version>
<CNNTrain.version>0.3.10-SNAPSHOT</CNNTrain.version>
<cnnarch-generator.version>0.0.6-SNAPSHOT</cnnarch-generator.version>
<cnnarch-mxnet-generator.version>0.2.17-SNAPSHOT</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>0.2.14-SNAPSHOT</cnnarch-caffe2-generator.version>
<cnnarch-gluon-generator.version>0.2.10-SNAPSHOT</cnnarch-gluon-generator.version>
......@@ -94,7 +94,7 @@
<artifactId>common-monticar</artifactId>
<version>${Common-MontiCar.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-tensorflow-generator</artifactId>
......
......@@ -14,6 +14,7 @@ import de.monticore.lang.monticar.cnnarch._symboltable.NetworkInstructionSymbol;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.generator.CNNTrainGenerator;
import de.monticore.lang.monticar.cnnarch.generator.DataPathConfigParser;
import de.monticore.lang.monticar.cnnarch.generator.WeightsPathConfigParser;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNTrain2Gluon;
import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing.PreprocessingComponentParameterAdapter;
......@@ -246,7 +247,7 @@ public class EMADLGenerator {
String b = backend.getBackendString(backend);
String trainingDataHash = "";
String testDataHash = "";
if (architecture.get().getDataPath() != null) {
if (b.equals("CAFFE2")) {
trainingDataHash = getChecksumForLargerFile(architecture.get().getDataPath() + "/train_lmdb/data.mdb");
......@@ -410,6 +411,22 @@ public class EMADLGenerator {
return dataPath;
}
protected String getWeightsPath(EMAComponentSymbol component, EMAComponentInstanceSymbol instance){
String weightsPath;
// TODO check if pretrained true, otherwise return null
Path weightsPathDefinition = Paths.get(getModelsPath(), "weights_paths.txt");
if (weightsPathDefinition.toFile().exists()) {
WeightsPathConfigParser newParserConfig = new WeightsPathConfigParser(getModelsPath() + "weights_paths.txt");
weightsPath = newParserConfig.getWeightsPath(component.getFullName());
} else {
Log.info("No weights path definition found in " + weightsPathDefinition + ": "
+ "No pretrained weights will be loaded.", "EMADLGenerator");
weightsPath = null;
}
return weightsPath;
}
protected void generateComponent(List<FileContent> fileContents,
Set<EMAComponentInstanceSymbol> allInstances,
TaggingResolver taggingResolver,
......@@ -431,7 +448,9 @@ public class EMADLGenerator {
if (architecture.isPresent()){
cnnArchGenerator.check(architecture.get());
String dPath = getDataPath(taggingResolver, EMAComponentSymbol, componentInstanceSymbol);
String wPath = getWeightsPath(EMAComponentSymbol, componentInstanceSymbol);
architecture.get().setDataPath(dPath);
architecture.get().setWeightsPath(wPath);
architecture.get().setComponentName(EMAComponentSymbol.getFullName());
generateCNN(fileContents, taggingResolver, componentInstanceSymbol, architecture.get());
if (processedArchitecture != null) {
......
......@@ -50,13 +50,42 @@ class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
if self._sparse_label:
loss = -pick(pred, label, axis=self._axis, keepdims=True)
else:
label = _reshape_like(F, label, pred)
label = gluon.loss._reshape_like(F, label, pred)
loss = -(pred * label).sum(axis=self._axis, keepdims=True)
# ignore some indices for loss, e.g. <pad> tokens in NLP applications
for i in self._ignore_indices:
loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i))
loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i) * mx.nd.equal(mx.nd.argmax(pred, axis=1), label))
return loss.mean(axis=self._batch_axis, exclude=True)
class DiceLoss(gluon.loss.Loss):
def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
batch_axis=0, **kwargs):
super(DiceLoss, self).__init__(weight, batch_axis, **kwargs)
self._axis = axis
self._sparse_label = sparse_label
self._from_logits = from_logits
def dice_loss(self, F, pred, label):
smooth = 1.
pred_y = F.argmax(pred, axis = self._axis)
intersection = pred_y * label
score = (2 * F.mean(intersection, axis=self._batch_axis, exclude=True) + smooth) \
/ (F.mean(label, axis=self._batch_axis, exclude=True) + F.mean(pred_y, axis=self._batch_axis, exclude=True) + smooth)
return - F.log(score)
def hybrid_forward(self, F, pred, label, sample_weight=None):
if not self._from_logits:
pred = F.log_softmax(pred, self._axis)
if self._sparse_label:
loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
else:
label = gluon.loss._reshape_like(F, label, pred)
loss = -F.sum(pred*label, axis=self._axis, keepdims=True)
loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
diceloss = self.dice_loss(F, pred, label)
return F.mean(loss, axis=self._batch_axis, exclude=True) + diceloss
@mx.metric.register
class BLEU(mx.metric.EvalMetric):
N = 4
......@@ -244,14 +273,17 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
if loss == 'softmax_cross_entropy':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'softmax_cross_entropy_ignore_indices':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
elif loss == 'cross_entropy':
loss_function = CrossEntropyLoss(sparse_label=sparseLabel)
loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'dice_loss':
loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, weight=loss_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1':
......@@ -323,7 +355,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
train_test_iter.reset()
metric = mx.metric.create(eval_metric, **eval_metric_params)
for batch_i, batch in enumerate(train_test_iter):
if True:
if True:
labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
image_ = batch.data[0].as_in_context(mx_context)
......@@ -394,7 +426,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
test_iter.reset()
metric = mx.metric.create(eval_metric, **eval_metric_params)
for batch_i, batch in enumerate(test_iter):
if True:
if True:
labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
image_ = batch.data[0].as_in_context(mx_context)
......
import mxnet as mx
import logging
import os
import shutil
from CNNNet_mnist_mnistClassifier_net import Net_0
......@@ -11,6 +12,7 @@ class CNNCreator_mnist_mnistClassifier_net:
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.networks = {}
self._weights_dir_ = None
def load(self, context):
earliestLastEpoch = None
......@@ -47,6 +49,29 @@ class CNNCreator_mnist_mnistClassifier_net:
return earliestLastEpoch
def load_pretrained_weights(self, context):
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
if self._weights_dir_ is not None:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
......
......@@ -50,13 +50,89 @@ class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
if self._sparse_label:
loss = -pick(pred, label, axis=self._axis, keepdims=True)
else:
label = _reshape_like(F, label, pred)
label = gluon.loss._reshape_like(F, label, pred)
loss = -(pred * label).sum(axis=self._axis, keepdims=True)
# ignore some indices for loss, e.g. <pad> tokens in NLP applications
for i in self._ignore_indices:
loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i) * mx.nd.equal(mx.nd.argmax(pred, axis=1), label))
return loss.mean(axis=self._batch_axis, exclude=True)
class DiceLoss(gluon.loss.Loss):
def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
batch_axis=0, **kwargs):
super(DiceLoss, self).__init__(weight, batch_axis, **kwargs)
self._axis = axis
self._sparse_label = sparse_label
self._from_logits = from_logits
def dice_loss(self, F, pred, label):
smooth = 1.
pred_y = F.argmax(pred, axis = self._axis)
intersection = pred_y * label
score = (2 * F.mean(intersection, axis=self._batch_axis, exclude=True) + smooth) \
/ (F.mean(label, axis=self._batch_axis, exclude=True) + F.mean(pred_y, axis=self._batch_axis, exclude=True) + smooth)
return - F.log(score)
def hybrid_forward(self, F, pred, label, sample_weight=None):
if not self._from_logits:
pred = F.log_softmax(pred, self._axis)
if self._sparse_label:
loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
else:
label = gluon.loss._reshape_like(F, label, pred)
loss = -F.sum(pred*label, axis=self._axis, keepdims=True)
loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
diceloss = self.dice_loss(F, pred, label)
return F.mean(loss, axis=self._batch_axis, exclude=True) + diceloss
class SoftmaxCrossEntropyLossIgnoreLabel(gluon.loss.Loss):
def __init__(self, axis=-1, from_logits=False, weight=None,
batch_axis=0, ignore_label=255, **kwargs):
super(SoftmaxCrossEntropyLossIgnoreLabel, self).__init__(weight, batch_axis, **kwargs)
self._axis = axis
self._from_logits = from_logits
self._ignore_label = ignore_label
def hybrid_forward(self, F, output, label, sample_weight=None):
if not self._from_logits:
output = F.log_softmax(output, axis=self._axis)
valid_label_map = (label != self._ignore_label)
loss = -(F.pick(output, label, axis=self._axis, keepdims=True) * valid_label_map )
loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
return F.sum(loss) / F.sum(valid_label_map)
@mx.metric.register
class ACCURACY_IGNORE_LABEL(mx.metric.EvalMetric):
"""Ignores a label when computing accuracy.
"""
def __init__(self, axis=1, metric_ignore_label=255, name='accuracy',
output_names=None, label_names=None):
super(ACCURACY_IGNORE_LABEL, self).__init__(
name, axis=axis,
output_names=output_names, label_names=label_names)
self.axis = axis
self.ignore_label = metric_ignore_label
def update(self, labels, preds):
mx.metric.check_label_shapes(labels, preds)
for label, pred_label in zip(labels, preds):
if pred_label.shape != label.shape:
pred_label = mx.nd.argmax(pred_label, axis=self.axis, keepdims=True)
label = label.astype('int32')
pred_label = pred_label.astype('int32').as_in_context(label.context)
mx.metric.check_label_shapes(label, pred_label)
correct = mx.nd.sum( (label == pred_label) * (label != self.ignore_label) ).asscalar()
total = mx.nd.sum( (label != self.ignore_label) ).asscalar()
self.sum_metric += correct
self.num_inst += total
@mx.metric.register
class BLEU(mx.metric.EvalMetric):
N = 4
......@@ -192,6 +268,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
optimizer_params=(('learning_rate', 0.001),),
load_checkpoint=True,
checkpoint_period=5,
load_pretrained=False,
log_period=50,
context='gpu',
save_attention_image=False,
......@@ -236,6 +313,8 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
begin_epoch = 0
if load_checkpoint:
begin_epoch = self._net_creator.load(mx_context)
elif load_pretrained:
self._net_creator.load_pretrained_weights(mx_context)
else:
if os.path.isdir(self._net_creator._model_dir_):
shutil.rmtree(self._net_creator._model_dir_)
......@@ -253,16 +332,25 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
margin = loss_params['margin'] if 'margin' in loss_params else 1.0
sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1
batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0
if loss == 'softmax_cross_entropy':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'softmax_cross_entropy_ignore_indices':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
elif loss == 'cross_entropy':
loss_function = CrossEntropyLoss(sparse_label=sparseLabel)
loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'dice_loss':
loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, weight=loss_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'softmax_cross_entropy_ignore_label':
loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
loss_ignore_label = loss_params['loss_ignore_label'] if 'loss_ignore_label' in loss_params else None
loss_function = SoftmaxCrossEntropyLossIgnoreLabel(axis=loss_axis, ignore_label=loss_ignore_label, weight=loss_weight, batch_axis=batch_axis)
elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1':
......@@ -510,11 +598,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1]
......
import mxnet as mx
import logging
import os
import shutil
from CNNNet_defaultGAN_defaultGANConnector_predictor import Net_0
......@@ -11,6 +12,7 @@ class CNNCreator_defaultGAN_defaultGANConnector_predictor:
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.networks = {}
self._weights_dir_ = None
def load(self, context):
earliestLastEpoch = None
......@@ -47,6 +49,29 @@ class CNNCreator_defaultGAN_defaultGANConnector_predictor:
return earliestLastEpoch
def load_pretrained_weights(self, context):
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
if self._weights_dir_ is not None:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
......
import mxnet as mx
import logging
import os
import shutil
from CNNNet_defaultGAN_defaultGANDiscriminator import Net_0
......@@ -11,6 +12,7 @@ class CNNCreator_defaultGAN_defaultGANDiscriminator:
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.networks = {}
self._weights_dir_ = None
def load(self, context):
earliestLastEpoch = None
......@@ -47,6 +49,29 @@ class CNNCreator_defaultGAN_defaultGANDiscriminator:
return earliestLastEpoch
def load_pretrained_weights(self, context):
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
if self._weights_dir_ is not None:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
......
import mxnet as mx
import logging
import os
import shutil
from CNNNet_infoGAN_infoGANConnector_predictor import Net_0
......@@ -11,6 +12,7 @@ class CNNCreator_infoGAN_infoGANConnector_predictor:
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.networks = {}
self._weights_dir_ = None
def load(self, context):
earliestLastEpoch = None
......@@ -47,6 +49,29 @@ class CNNCreator_infoGAN_infoGANConnector_predictor:
return earliestLastEpoch
def load_pretrained_weights(self, context):
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
if self._weights_dir_ is not None:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
......
import mxnet as mx
import logging
import os
import shutil
from CNNNet_infoGAN_infoGANDiscriminator import Net_0
......@@ -11,6 +12,7 @@ class CNNCreator_infoGAN_infoGANDiscriminator:
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.networks = {}
self._weights_dir_ = None
def load(self, context):
earliestLastEpoch = None
......@@ -47,6 +49,29 @@ class CNNCreator_infoGAN_infoGANDiscriminator:
return earliestLastEpoch
def load_pretrained_weights(self, context):
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
if self._weights_dir_ is not None:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std