Commit 7c63e291 authored by Abdallah Atouani's avatar Abdallah Atouani
Browse files

fix tests

parent b25c3df9
Pipeline #350117 passed with stage
in 1 minute and 49 seconds
......@@ -674,8 +674,13 @@ class CNNSupervisedTrainer_Alexnet:
global_loss_test /= (test_batches * single_pu_batch_size)
test_metric_name = metric.get()[0]
test_metric_score = metric.get()[1]
metric_file = open(self._net_creator._model_dir_ + 'metric.txt', 'w')
metric_file.write(test_metric_name + " : " + str(test_metric_score))
metric_file.close()
logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test))
if (epoch+1) % checkpoint_period == 0:
......
......@@ -674,8 +674,13 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
global_loss_test /= (test_batches * single_pu_batch_size)
test_metric_name = metric.get()[0]
test_metric_score = metric.get()[1]
metric_file = open(self._net_creator._model_dir_ + 'metric.txt', 'w')
metric_file.write(test_metric_name + " : " + str(test_metric_score))
metric_file.close()
logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test))
if (epoch+1) % checkpoint_period == 0:
......
......@@ -795,10 +795,15 @@ class CNNSupervisedTrainer_EpisodicMemoryNetwork:
metric.update(preds=predictions, labels=[labels[j][local_adaptation_batch_i] for j in range(len(labels))])
self._networks[0].collect_params().load_dict(params[0], ctx=mx_context[0])
global_loss_test /= (test_batches)
global_loss_test /= (test_batches)
test_metric_name = metric.get()[0]
test_metric_score = metric.get()[1]
metric_file = open(self._net_creator._model_dir_ + 'metric.txt', 'w')
metric_file.write(test_metric_name + " : " + str(test_metric_score))
metric_file.close()
logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test))
if (epoch+1) % checkpoint_period == 0:
......
......@@ -674,8 +674,13 @@ class CNNSupervisedTrainer_LoadNetworkTest:
global_loss_test /= (test_batches * single_pu_batch_size)
test_metric_name = metric.get()[0]
test_metric_score = metric.get()[1]
metric_file = open(self._net_creator._model_dir_ + 'metric.txt', 'w')
metric_file.write(test_metric_name + " : " + str(test_metric_score))
metric_file.close()
logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test))
if (epoch+1) % checkpoint_period == 0:
......
......@@ -674,8 +674,13 @@ class CNNSupervisedTrainer_VGG16:
global_loss_test /= (test_batches * single_pu_batch_size)
test_metric_name = metric.get()[0]
test_metric_score = metric.get()[1]
metric_file = open(self._net_creator._model_dir_ + 'metric.txt', 'w')
metric_file.write(test_metric_name + " : " + str(test_metric_score))
metric_file.close()
logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test))
if (epoch+1) % checkpoint_period == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment