saveAttentionImageTest.ftl
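                    # Attention-visualization block for the Show_attend_tell test. It assumes that os, pickle,
                    # logging, numpy (np), mxnet (mx) and the evaluation variables (labels, outputs, attentionList,
                    # test_images, batch_size, batch_i, eval_train) are defined by the enclosing generated code.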
                    if save_attention_image == "True":
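                        # Configure matplotlib with the non-interactive Agg backend so figures render without a
                        # display, and silence matplotlib's verbose logging.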
                        if not eval_train:
                            import matplotlib
                            matplotlib.use('Agg')
                            import matplotlib.pyplot as plt
                            logging.getLogger('matplotlib').setLevel(logging.ERROR)

                            # Load the index-to-word vocabulary written during preprocessing, if it is available.
                            if os.path.isfile('src/test/resources/training_data/Show_attend_tell/dict.pkl'):
                                with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
                                    dict = pickle.load(f)  # index -> word mapping (shadows the builtin name 'dict')

                        plt.clf()
                        fig = plt.figure(figsize=(15,15))
                        max_length = len(labels)-1

                        # First subplot: the input image, transposed from CHW to HWC for imshow.
                        ax = fig.add_subplot(max_length//3, max_length//4, 1)
                        ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))

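                        # For each decoded token, overlay its attention map (resized to an 8x8 grid) on the input image.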
                        for l in range(max_length):
                            attention = attentionList[l]
                            # Attention weights of the first batch element, resized to an 8x8 grid.
                            attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
                            attention_resized = np.resize(attention.asnumpy(), (8, 8))
                            ax = fig.add_subplot(max_length//3, max_length//4, l+2)
                            # Predicted token index of the first batch element at decoding step l+1.
                            token = int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())
                            if token > len(dict):
                                ax.set_title("<unk>")
                            elif dict[token] == "<end>":
                                ax.set_title(".")
                                img = ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))
                                ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                                break
                            else:
                                ax.set_title(dict[token])
                            img = ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))
                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())

                        plt.tight_layout()
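                        # Write the composite figure to target/attention_images, creating the directory first if needed.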
                        target_dir = 'target/attention_images'
                        if not os.path.exists(target_dir):
                            os.makedirs(target_dir)
                        plt.savefig(target_dir + '/attention_test.png')
                        plt.close()