Skip to content
Snippets Groups Projects
Commit 96c24a78 authored by Christian Fuß's avatar Christian Fuß
Browse files

fixed batch size in Reshape layer. Adjusted generated attention images

parent 97af19e9
No related branches found
No related tags found
1 merge request!23Added Unroll-related features and layers
Pipeline #208394 failed
<#assign input = element.inputs[0]>
<#if mode == "FORWARD_FUNCTION">
${element.name} = F.reshape(${input}, shape=(${tc.join(element.shape, ",")}))
${element.name} = F.reshape(${input}, shape=(0,${tc.join(element.shape, ",")}))
</#if>
\ No newline at end of file
if save_attention_image == "True":
plt.clf()
fig = plt.figure(figsize=(10,10))
fig = plt.figure(figsize=(15,15))
max_length = len(labels)-1
ax = fig.add_subplot(max_length//3, max_length//4, 1)
ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
for l in range(max_length):
attention = attentionList[l]
attention = mx.nd.slice_axis(attention, axis=2, begin=0, end=1)
attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1)
attention = mx.nd.squeeze(attention)
attention_resized = np.resize(attention.asnumpy(), (8, 8))
ax = fig.add_subplot(max_length//3, max_length//4, l+1)
ax.set_title(dict[int(mx.nd.slice_axis(mx.nd.argmax(outputs[l+1], axis=1), axis=0, begin=0, end=1).asscalar())])
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
ax = fig.add_subplot(max_length//3, max_length//4, l+2)
if dict[int(mx.nd.slice_axis(mx.nd.argmax(outputs[l+1], axis=1), axis=0, begin=0, end=1).asscalar())] == "<end>":
ax.set_title(".")
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
break
else:
ax.set_title(dict[int(mx.nd.slice_axis(mx.nd.argmax(outputs[l+1], axis=1), axis=0, begin=0, end=1).asscalar())])
img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
plt.tight_layout()
......
......@@ -3,22 +3,32 @@
logging.getLogger('matplotlib').setLevel(logging.ERROR)
plt.clf()
fig = plt.figure(figsize=(10,10))
fig = plt.figure(figsize=(15,15))
max_length = len(labels)-1
if(os.path.isfile('src/test/resources/training_data/Show_attend_tell/dict.pkl')):
with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
dict = pickle.load(f)
ax = fig.add_subplot(max_length//3, max_length//4, 1)
ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
for l in range(max_length):
attention = attentionList[l]
attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1)
attention = mx.nd.squeeze(attention)
attention_resized = np.resize(attention.asnumpy(), (8, 8))
ax = fig.add_subplot(max_length//3, max_length//4, l+1)
ax.set_title(dict[int(labels[l+1][0].asscalar())])
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
ax = fig.add_subplot(max_length//3, max_length//4, l+2)
if dict[int(labels[l+1][0].asscalar())] == "<end>":
ax.set_title(".")
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
break
else:
ax.set_title(dict[int(labels[l+1][0].asscalar())])
img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
plt.tight_layout()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment