
Commit 7917840b authored by Sebastian N.

Fixed tests


Former-commit-id: 39783063
parent becfc971
@@ -44,7 +44,7 @@ public:
     MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
     size = 1;
     for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
-    assert(out_predictions_.size() == 1 || size == out_predictions_.size());
+    assert(size == out_predictions_.size());
     MXPredGetOutput(handle, output_index, &(out_predictions_[0]), out_predictions_.size());
 }
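This hunk (and the four analogous predictor hunks further down) tightens the size check around the MXNet C Predict API: the flat output buffer must hold exactly the product of the dimensions reported by MXPredGetOutputShape before MXPredGetOutput fills it. A minimal sketch of that invariant, with hypothetical names, assuming nothing beyond NumPy:

import numpy as np

def check_output_buffer(shape, out_buffer):
    # product of the reported dimensions = number of values the buffer must hold
    expected = int(np.prod(shape))
    assert expected == len(out_buffer), \
        "buffer holds %d values, shape %s implies %d" % (len(out_buffer), shape, expected)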
@@ -189,6 +189,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
               context='gpu',
               checkpoint_period=5,
               save_attention_image=False,
+              use_teacher_forcing=False,
               normalize=True):
         if context == 'gpu':
             mx_context = mx.gpu()
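The new use_teacher_forcing flag is only threaded through the train() signature here. As a hedged sketch of what such a switch conventionally selects in a sequence decoder (hypothetical helper, not the generated trainer itself):

def next_decoder_input(labels, prev_prediction, step, use_teacher_forcing):
    # teacher forcing: condition step t on the ground-truth token from step t-1
    if use_teacher_forcing:
        return labels[step - 1]
    # free running: condition on the model's own previous prediction
    return prev_prediction

Feeding the ground truth typically stabilizes early training, while free running matches how the network is used at inference time.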
@@ -241,12 +242,12 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
         margin = loss_params['margin'] if 'margin' in loss_params else 1.0
         sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
-        #if loss == 'softmax_cross_entropy':
-        #    fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
-        #    loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
+        ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
         if loss == 'softmax_cross_entropy':
             fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
-            ignore_indices = [2]
             loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
         if loss == 'softmax_cross_entropy_ignore_indices':
             fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
             loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
         elif loss == 'sigmoid_binary_cross_entropy':
             loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
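The hard-coded ignore_indices = [2] gives way to a value read from loss_params, and the masking variant is dispatched under the name 'softmax_cross_entropy_ignore_indices'. The SoftmaxCrossEntropyLossIgnoreIndices class itself is not shown in this diff; the following is only a functional sketch of the masking idea it presumably implements (assumption flagged in the comments):

import mxnet as mx

def softmax_ce_ignore_indices(pred, label, ignore_indices, from_logits=False):
    # Sketch only; the real class ships with the generator and is not in this diff.
    base = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=from_logits, sparse_label=True)
    loss = base(pred, label)                 # per-sample cross entropy, shape (batch,)
    mask = mx.nd.ones_like(label)
    for idx in ignore_indices:
        mask = mask * (label != idx)         # 0.0 wherever the label must be ignored
    return loss * mask

Masking by label index is the usual way to keep padding or end-of-sequence tokens from dominating the loss of a sequence model.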
@@ -288,6 +289,8 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
                 predictions_ = mx.nd.zeros((train_batch_size, 10,), ctx=mx_context)
+                nd.waitall()
+                lossList = []
                 predictions_ = self._networks[0](image_)
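nd.waitall() blocks until MXNet's asynchronous engine has drained, so later failures and timings surface deterministically, which matters for tests; lossList collects one loss term per network output so that multi-output models can sum their losses before the backward pass. The synchronization call in isolation (standard mxnet.nd API):

from mxnet import nd

a = nd.ones((2, 2))
b = a * 2       # queued on MXNet's asynchronous engine, returns immediately
nd.waitall()    # block until every pending operation has actually finished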
@@ -321,7 +324,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
         train_test_iter.reset()
         metric = mx.metric.create(eval_metric, **eval_metric_params)
         for batch_i, batch in enumerate(train_test_iter):
-            if True:
+            if True:
                 labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
                 image_ = batch.data[0].as_in_context(mx_context)
@@ -329,6 +332,8 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
                 predictions_ = mx.nd.zeros((test_batch_size, 10,), ctx=mx_context)
+                nd.waitall()
                 outputs = []
+                attentionList=[]
                 predictions_ = self._networks[0](image_)
@@ -341,22 +346,32 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
                     logging.getLogger('matplotlib').setLevel(logging.ERROR)
                     plt.clf()
-                    fig = plt.figure(figsize=(10,10))
+                    fig = plt.figure(figsize=(15,15))
                     max_length = len(labels)-1
+                    if(os.path.isfile('src/test/resources/training_data/Show_attend_tell/dict.pkl')):
+                        with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
+                            dict = pickle.load(f)
+                    ax = fig.add_subplot(max_length//3, max_length//4, 1)
+                    ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
                     for l in range(max_length):
                         attention = attentionList[l]
                         attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1)
                         attention = mx.nd.squeeze(attention)
                         attention_resized = np.resize(attention.asnumpy(), (8, 8))
-                        ax = fig.add_subplot(max_length//3, max_length//4, l+1)
-                        ax.set_title(dict[int(labels[l+1][0].asscalar())])
-                        img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
-                        ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
+                        ax = fig.add_subplot(max_length//3, max_length//4, l+2)
+                        if dict[int(labels[l+1][0].asscalar())] == "<end>":
+                            ax.set_title(".")
+                            img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
+                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
+                            break
+                        else:
+                            ax.set_title(dict[int(labels[l+1][0].asscalar())])
+                            img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
+                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                     plt.tight_layout()
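The visualization resizes each decoding step's attention vector to the 8x8 feature grid and alpha-blends it over the input image, one subplot per decoded token until "<end>". The overlay trick in isolation (hypothetical helper names; the matplotlib/NumPy calls are the same ones used above):

import numpy as np
import matplotlib.pyplot as plt

def overlay_attention(ax, image_hwc, attention_weights, grid=(8, 8)):
    att = np.resize(np.asarray(attention_weights), grid)  # flat weights -> feature grid
    img = ax.imshow(image_hwc)                            # the input image (HWC layout)
    ax.imshow(att, cmap='gray', alpha=0.6, extent=img.get_extent())  # translucent overlay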
@@ -379,7 +394,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
         test_iter.reset()
         metric = mx.metric.create(eval_metric, **eval_metric_params)
         for batch_i, batch in enumerate(test_iter):
-            if True:
+            if True:
                 labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
                 image_ = batch.data[0].as_in_context(mx_context)
@@ -387,6 +402,8 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
                 predictions_ = mx.nd.zeros((test_batch_size, 10,), ctx=mx_context)
+                nd.waitall()
                 outputs = []
+                attentionList=[]
                 predictions_ = self._networks[0](image_)
@@ -396,19 +413,28 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
                 if save_attention_image == "True":
                     plt.clf()
-                    fig = plt.figure(figsize=(10,10))
+                    fig = plt.figure(figsize=(15,15))
                     max_length = len(labels)-1
+                    ax = fig.add_subplot(max_length//3, max_length//4, 1)
+                    ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
                     for l in range(max_length):
                         attention = attentionList[l]
-                        attention = mx.nd.slice_axis(attention, axis=2, begin=0, end=1)
+                        attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1)
                         attention = mx.nd.squeeze(attention)
                         attention_resized = np.resize(attention.asnumpy(), (8, 8))
-                        ax = fig.add_subplot(max_length//3, max_length//4, l+1)
-                        ax.set_title(dict[int(mx.nd.slice_axis(mx.nd.argmax(outputs[l+1], axis=1), axis=0, begin=0, end=1).asscalar())])
-                        img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
-                        ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
+                        ax = fig.add_subplot(max_length//3, max_length//4, l+2)
+                        if dict[int(mx.nd.slice_axis(mx.nd.argmax(outputs[l+1], axis=1), axis=0, begin=0, end=1).asscalar())] == "<end>":
+                            ax.set_title(".")
+                            img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
+                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
+                            break
+                        else:
+                            ax.set_title(dict[int(mx.nd.slice_axis(mx.nd.argmax(outputs[l+1], axis=1), axis=0, begin=0, end=1).asscalar())])
+                            img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
+                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                     plt.tight_layout()
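On the test split there are no ground-truth captions to title the subplots with, so the title comes from the network's own outputs: argmax over the class axis at each step, first sample of the batch, mapped through the pickled index-to-word dict. The same lookup in isolation (hypothetical names):

import mxnet as mx

def predicted_word(step_output, index_to_word):
    best = mx.nd.argmax(step_output, axis=1)                 # (batch,) best class per sample
    first = mx.nd.slice_axis(best, axis=0, begin=0, end=1)   # keep only the first sample
    return index_to_word[int(first.asscalar())]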
@@ -44,7 +44,7 @@ public:
     MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
     size = 1;
     for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
-    assert(out_qvalues_.size() == 1 || size == out_qvalues_.size());
+    assert(size == out_qvalues_.size());
     MXPredGetOutput(handle, output_index, &(out_qvalues_[0]), out_qvalues_.size());
 }
@@ -44,7 +44,7 @@ public:
     MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
     size = 1;
     for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
-    assert(out_action_.size() == 1 || size == out_action_.size());
+    assert(size == out_action_.size());
    MXPredGetOutput(handle, output_index, &(out_action_[0]), out_action_.size());
 }
@@ -44,7 +44,7 @@ public:
     MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
     size = 1;
     for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
-    assert(out_qvalues_.size() == 1 || size == out_qvalues_.size());
+    assert(size == out_qvalues_.size());
     MXPredGetOutput(handle, output_index, &(out_qvalues_[0]), out_qvalues_.size());
 }
@@ -44,7 +44,7 @@ public:
     MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
     size = 1;
     for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
-    assert(out_commands_.size() == 1 || size == out_commands_.size());
+    assert(size == out_commands_.size());
     MXPredGetOutput(handle, output_index, &(out_commands_[0]), out_commands_.size());
 }