Commit 51bb07fa authored by Sebastian Nickels's avatar Sebastian Nickels

Updated for parameters checkpoint_period, log_period and eval_train

parent 1bf1621d
Pipeline #221321 failed with stages
in 5 minutes and 31 seconds
......@@ -14,7 +14,7 @@ class CNNDataLoader_mnist_mnistClassifier_net:
def __init__(self):
self._data_dir = "data/mnist.LeNetNetwork/"
def load_data(self, train_batch_size, test_batch_size):
def load_data(self, batch_size):
train_h5, test_h5 = self.load_h5_files()
train_data = {}
......@@ -38,11 +38,7 @@ class CNNDataLoader_mnist_mnistClassifier_net:
train_iter = mx.io.NDArrayIter(data=train_data,
label=train_label,
batch_size=train_batch_size)
train_test_iter = mx.io.NDArrayIter(data=train_data,
label=train_label,
batch_size=test_batch_size)
batch_size=batch_size)
test_iter = None
......@@ -63,9 +59,9 @@ class CNNDataLoader_mnist_mnistClassifier_net:
test_iter = mx.io.NDArrayIter(data=test_data,
label=test_label,
batch_size=test_batch_size)
batch_size=batch_size)
return train_iter, train_test_iter, test_iter, data_mean, data_std, train_images, test_images
return train_iter, test_iter, data_mean, data_std, train_images, test_images
def load_data_img(self, batch_size, img_size):
train_h5, test_h5 = self.load_h5_files()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment