Commit 9072f8de authored by Sebastian Nickels's avatar Sebastian Nickels

Updated

parent 2bf1f19f
Pipeline #226593 failed with stages
in 1 minute and 37 seconds
......@@ -230,7 +230,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
begin_epoch = 0
if load_checkpoint:
begin_epoch = self._net_creator.load(mx_context) + 1
begin_epoch = self._net_creator.load(mx_context)
else:
if os.path.isdir(self._net_creator._model_dir_):
shutil.rmtree(self._net_creator._model_dir_)
......@@ -320,7 +320,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
loss_total += loss.sum().asscalar()
global_loss_train += float(loss.mean().asscalar())
global_loss_train += loss.sum().asscalar()
train_batches += 1
if clip_global_grad_norm:
......@@ -350,8 +350,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
tic = time.time()
if train_batches > 0:
global_loss_train /= train_batches
global_loss_train /= (train_batches * batch_size)
tic = None
......@@ -501,7 +500,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
for element in lossList:
loss = loss + element
global_loss_test += float(loss.mean().asscalar())
global_loss_test += loss.sum().asscalar()
test_batches += 1
predictions = []
......@@ -515,8 +514,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1]
if test_batches > 0:
global_loss_test /= test_batches
global_loss_test /= (test_batches * batch_size)
logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment