Commit 303f6e12 authored by Christian Fuß's avatar Christian Fuß

small adjustment in showAttentionImage template

parent c6eb036e
Pipeline #211584 failed with stages
in 57 seconds
...@@ -3,21 +3,16 @@ ...@@ -3,21 +3,16 @@
fig = plt.figure(figsize=(15,15)) fig = plt.figure(figsize=(15,15))
max_length = len(labels)-1 max_length = len(labels)-1
ax = fig.add_subplot(max_length//3, max_length//4, 1) ax = fig.add_subplot(max_length//3, max_length//4, 1)
ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
for l in range(max_length): for l in range(max_length):
attention = attentionList[l] attention = attentionList[l]
attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1) attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
attention = mx.nd.squeeze(attention)
attention_resized = np.resize(attention.asnumpy(), (8, 8)) attention_resized = np.resize(attention.asnumpy(), (8, 8))
ax = fig.add_subplot(max_length//3, max_length//4, l+2) ax = fig.add_subplot(max_length//3, max_length//4, l+2)
if int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar()) > len(dict): if int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar()) > len(dict):
ax.set_title("<unk>") ax.set_title("<unk>")
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
break
elif dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>": elif dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>":
ax.set_title(".") ax.set_title(".")
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
...@@ -25,9 +20,8 @@ ...@@ -25,9 +20,8 @@
break break
else: else:
ax.set_title(dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())]) ax.set_title(dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())])
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent()) ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
plt.tight_layout() plt.tight_layout()
plt.savefig(target_dir + '/attention_test.png') plt.savefig(target_dir + '/attention_test.png')
......
...@@ -10,21 +10,16 @@ ...@@ -10,21 +10,16 @@
with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f: with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
dict = pickle.load(f) dict = pickle.load(f)
ax = fig.add_subplot(max_length//3, max_length//4, 1) ax = fig.add_subplot(max_length//3, max_length//4, 1)
ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
for l in range(max_length): for l in range(max_length):
attention = attentionList[l] attention = attentionList[l]
attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1) attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
attention = mx.nd.squeeze(attention)
attention_resized = np.resize(attention.asnumpy(), (8, 8)) attention_resized = np.resize(attention.asnumpy(), (8, 8))
ax = fig.add_subplot(max_length//3, max_length//4, l+2) ax = fig.add_subplot(max_length//3, max_length//4, l+2)
if int(labels[l+1][0].asscalar()) > len(dict): if int(labels[l+1][0].asscalar()) > len(dict):
ax.set_title("<unk>") ax.set_title("<unk>")
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
break
elif dict[int(labels[l+1][0].asscalar())] == "<end>": elif dict[int(labels[l+1][0].asscalar())] == "<end>":
ax.set_title(".") ax.set_title(".")
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
...@@ -32,13 +27,12 @@ ...@@ -32,13 +27,12 @@
break break
else: else:
ax.set_title(dict[int(labels[l+1][0].asscalar())]) ax.set_title(dict[int(labels[l+1][0].asscalar())])
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent()) ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
plt.tight_layout() plt.tight_layout()
target_dir = 'target/attention_images' target_dir = 'target/attention_images'
if not os.path.exists(target_dir): if not os.path.exists(target_dir):
os.makedirs(target_dir) os.makedirs(target_dir)
plt.savefig(target_dir + '/attention_train.png') plt.savefig(target_dir + '/attention_train.png')
plt.close() plt.close()
\ No newline at end of file
...@@ -352,21 +352,16 @@ class CNNSupervisedTrainer_Alexnet: ...@@ -352,21 +352,16 @@ class CNNSupervisedTrainer_Alexnet:
with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f: with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
dict = pickle.load(f) dict = pickle.load(f)
ax = fig.add_subplot(max_length//3, max_length//4, 1) ax = fig.add_subplot(max_length//3, max_length//4, 1)
ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
for l in range(max_length): for l in range(max_length):
attention = attentionList[l] attention = attentionList[l]
attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1) attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
attention = mx.nd.squeeze(attention)
attention_resized = np.resize(attention.asnumpy(), (8, 8)) attention_resized = np.resize(attention.asnumpy(), (8, 8))
ax = fig.add_subplot(max_length//3, max_length//4, l+2) ax = fig.add_subplot(max_length//3, max_length//4, l+2)
if int(labels[l+1][0].asscalar()) > len(dict): if int(labels[l+1][0].asscalar()) > len(dict):
ax.set_title("<unk>") ax.set_title("<unk>")
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
break
elif dict[int(labels[l+1][0].asscalar())] == "<end>": elif dict[int(labels[l+1][0].asscalar())] == "<end>":
ax.set_title(".") ax.set_title(".")
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
...@@ -374,14 +369,13 @@ class CNNSupervisedTrainer_Alexnet: ...@@ -374,14 +369,13 @@ class CNNSupervisedTrainer_Alexnet:
break break
else: else:
ax.set_title(dict[int(labels[l+1][0].asscalar())]) ax.set_title(dict[int(labels[l+1][0].asscalar())])
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent()) ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
plt.tight_layout() plt.tight_layout()
target_dir = 'target/attention_images' target_dir = 'target/attention_images'
if not os.path.exists(target_dir): if not os.path.exists(target_dir):
os.makedirs(target_dir) os.makedirs(target_dir)
plt.savefig(target_dir + '/attention_train.png') plt.savefig(target_dir + '/attention_train.png')
plt.close() plt.close()
...@@ -420,21 +414,16 @@ class CNNSupervisedTrainer_Alexnet: ...@@ -420,21 +414,16 @@ class CNNSupervisedTrainer_Alexnet:
fig = plt.figure(figsize=(15,15)) fig = plt.figure(figsize=(15,15))
max_length = len(labels)-1 max_length = len(labels)-1
ax = fig.add_subplot(max_length//3, max_length//4, 1) ax = fig.add_subplot(max_length//3, max_length//4, 1)
ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
for l in range(max_length): for l in range(max_length):
attention = attentionList[l] attention = attentionList[l]
attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1) attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
attention = mx.nd.squeeze(attention)
attention_resized = np.resize(attention.asnumpy(), (8, 8)) attention_resized = np.resize(attention.asnumpy(), (8, 8))
ax = fig.add_subplot(max_length//3, max_length//4, l+2) ax = fig.add_subplot(max_length//3, max_length//4, l+2)
if int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar()) > len(dict): if int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar()) > len(dict):
ax.set_title("<unk>") ax.set_title("<unk>")
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
break
elif dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>": elif dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>":
ax.set_title(".") ax.set_title(".")
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
...@@ -442,9 +431,8 @@ class CNNSupervisedTrainer_Alexnet: ...@@ -442,9 +431,8 @@ class CNNSupervisedTrainer_Alexnet:
break break
else: else:
ax.set_title(dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())]) ax.set_title(dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())])
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent()) ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
plt.tight_layout() plt.tight_layout()
plt.savefig(target_dir + '/attention_test.png') plt.savefig(target_dir + '/attention_test.png')
......
...@@ -352,21 +352,16 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: ...@@ -352,21 +352,16 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f: with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
dict = pickle.load(f) dict = pickle.load(f)
ax = fig.add_subplot(max_length//3, max_length//4, 1) ax = fig.add_subplot(max_length//3, max_length//4, 1)
ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
for l in range(max_length): for l in range(max_length):
attention = attentionList[l] attention = attentionList[l]
attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1) attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
attention = mx.nd.squeeze(attention)
attention_resized = np.resize(attention.asnumpy(), (8, 8)) attention_resized = np.resize(attention.asnumpy(), (8, 8))
ax = fig.add_subplot(max_length//3, max_length//4, l+2) ax = fig.add_subplot(max_length//3, max_length//4, l+2)
if int(labels[l+1][0].asscalar()) > len(dict): if int(labels[l+1][0].asscalar()) > len(dict):
ax.set_title("<unk>") ax.set_title("<unk>")
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
break
elif dict[int(labels[l+1][0].asscalar())] == "<end>": elif dict[int(labels[l+1][0].asscalar())] == "<end>":
ax.set_title(".") ax.set_title(".")
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
...@@ -374,14 +369,13 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: ...@@ -374,14 +369,13 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
break break
else: else:
ax.set_title(dict[int(labels[l+1][0].asscalar())]) ax.set_title(dict[int(labels[l+1][0].asscalar())])
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent()) ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
plt.tight_layout() plt.tight_layout()
target_dir = 'target/attention_images' target_dir = 'target/attention_images'
if not os.path.exists(target_dir): if not os.path.exists(target_dir):
os.makedirs(target_dir) os.makedirs(target_dir)
plt.savefig(target_dir + '/attention_train.png') plt.savefig(target_dir + '/attention_train.png')
plt.close() plt.close()
...@@ -420,21 +414,16 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: ...@@ -420,21 +414,16 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
fig = plt.figure(figsize=(15,15)) fig = plt.figure(figsize=(15,15))
max_length = len(labels)-1 max_length = len(labels)-1
ax = fig.add_subplot(max_length//3, max_length//4, 1) ax = fig.add_subplot(max_length//3, max_length//4, 1)
ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
for l in range(max_length): for l in range(max_length):
attention = attentionList[l] attention = attentionList[l]
attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1) attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
attention = mx.nd.squeeze(attention)
attention_resized = np.resize(attention.asnumpy(), (8, 8)) attention_resized = np.resize(attention.asnumpy(), (8, 8))
ax = fig.add_subplot(max_length//3, max_length//4, l+2) ax = fig.add_subplot(max_length//3, max_length//4, l+2)
if int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar()) > len(dict): if int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar()) > len(dict):
ax.set_title("<unk>") ax.set_title("<unk>")
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
break
elif dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>": elif dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>":
ax.set_title(".") ax.set_title(".")
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
...@@ -442,9 +431,8 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: ...@@ -442,9 +431,8 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
break break
else: else:
ax.set_title(dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())]) ax.set_title(dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())])
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent()) ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
plt.tight_layout() plt.tight_layout()
plt.savefig(target_dir + '/attention_test.png') plt.savefig(target_dir + '/attention_test.png')
......
...@@ -352,21 +352,16 @@ class CNNSupervisedTrainer_VGG16: ...@@ -352,21 +352,16 @@ class CNNSupervisedTrainer_VGG16:
with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f: with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
dict = pickle.load(f) dict = pickle.load(f)
ax = fig.add_subplot(max_length//3, max_length//4, 1) ax = fig.add_subplot(max_length//3, max_length//4, 1)
ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
for l in range(max_length): for l in range(max_length):
attention = attentionList[l] attention = attentionList[l]
attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1) attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
attention = mx.nd.squeeze(attention)
attention_resized = np.resize(attention.asnumpy(), (8, 8)) attention_resized = np.resize(attention.asnumpy(), (8, 8))
ax = fig.add_subplot(max_length//3, max_length//4, l+2) ax = fig.add_subplot(max_length//3, max_length//4, l+2)
if int(labels[l+1][0].asscalar()) > len(dict): if int(labels[l+1][0].asscalar()) > len(dict):
ax.set_title("<unk>") ax.set_title("<unk>")
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
break
elif dict[int(labels[l+1][0].asscalar())] == "<end>": elif dict[int(labels[l+1][0].asscalar())] == "<end>":
ax.set_title(".") ax.set_title(".")
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
...@@ -374,14 +369,13 @@ class CNNSupervisedTrainer_VGG16: ...@@ -374,14 +369,13 @@ class CNNSupervisedTrainer_VGG16:
break break
else: else:
ax.set_title(dict[int(labels[l+1][0].asscalar())]) ax.set_title(dict[int(labels[l+1][0].asscalar())])
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent()) ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
plt.tight_layout() plt.tight_layout()
target_dir = 'target/attention_images' target_dir = 'target/attention_images'
if not os.path.exists(target_dir): if not os.path.exists(target_dir):
os.makedirs(target_dir) os.makedirs(target_dir)
plt.savefig(target_dir + '/attention_train.png') plt.savefig(target_dir + '/attention_train.png')
plt.close() plt.close()
...@@ -420,21 +414,16 @@ class CNNSupervisedTrainer_VGG16: ...@@ -420,21 +414,16 @@ class CNNSupervisedTrainer_VGG16:
fig = plt.figure(figsize=(15,15)) fig = plt.figure(figsize=(15,15))
max_length = len(labels)-1 max_length = len(labels)-1
ax = fig.add_subplot(max_length//3, max_length//4, 1) ax = fig.add_subplot(max_length//3, max_length//4, 1)
ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
for l in range(max_length): for l in range(max_length):
attention = attentionList[l] attention = attentionList[l]
attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1) attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
attention = mx.nd.squeeze(attention)
attention_resized = np.resize(attention.asnumpy(), (8, 8)) attention_resized = np.resize(attention.asnumpy(), (8, 8))
ax = fig.add_subplot(max_length//3, max_length//4, l+2) ax = fig.add_subplot(max_length//3, max_length//4, l+2)
if int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar()) > len(dict): if int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar()) > len(dict):
ax.set_title("<unk>") ax.set_title("<unk>")
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
break
elif dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>": elif dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>":
ax.set_title(".") ax.set_title(".")
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
...@@ -442,9 +431,8 @@ class CNNSupervisedTrainer_VGG16: ...@@ -442,9 +431,8 @@ class CNNSupervisedTrainer_VGG16:
break break
else: else:
ax.set_title(dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())]) ax.set_title(dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())])
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0)) img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent()) ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
plt.tight_layout() plt.tight_layout()
plt.savefig(target_dir + '/attention_test.png') plt.savefig(target_dir + '/attention_test.png')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment