# Extracted from saveAttentionImageTrain.ftl (code-viewer chrome removed)
                    if save_attention_image == "True":
                        # Render a grid of per-decode-step attention maps for the first
                        # sample of the current training batch and save the figure to
                        # target/attention_images/attention_train.png.
                        import matplotlib
                        matplotlib.use('Agg')  # headless backend: no display needed during training
                        import matplotlib.pyplot as plt
                        logging.getLogger('matplotlib').setLevel(logging.ERROR)

                        # Token-index -> word mapping for the subplot titles. Defaults to an
                        # empty mapping (every token titled "<unk>") when the pickle is
                        # missing, instead of crashing on the builtin `dict` as before.
                        index_to_word = {}
                        dict_path = 'src/test/resources/training_data/Show_attend_tell/dict.pkl'
                        if os.path.isfile(dict_path):
                            with open(dict_path, 'rb') as f:
                                index_to_word = pickle.load(f)

                        plt.clf()
                        fig = plt.figure(figsize=(15, 15))
                        max_length = len(labels) - 1

                        # First cell of the grid: the raw input image.
                        # NOTE(review): assumes train_images holds CHW arrays (transpose to
                        # HWC for imshow) -- confirm against the data pipeline. The grid is
                        # (max_length//3) x (max_length//4); for small max_length this may hold
                        # fewer than max_length+1 cells -- kept as in the original.
                        first_image = train_images[batch_size * batch_i].transpose(1, 2, 0)
                        ax = fig.add_subplot(max_length // 3, max_length // 4, 1)
                        ax.imshow(first_image)

                        for step in range(max_length):
                            # Attention weights of batch element 0 at this decode step,
                            # folded into an 8x8 map overlaid on the input image.
                            attention = attentionList[step]
                            attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
                            attention_resized = np.resize(attention.asnumpy(), (8, 8))

                            ax = fig.add_subplot(max_length // 3, max_length // 4, step + 2)
                            token_id = int(labels[step + 1][0].asscalar())
                            # .get also covers ids that are inside the old `> len(dict)` bound
                            # but absent from the mapping, which previously raised KeyError.
                            word = index_to_word.get(token_id)
                            if word is None:
                                ax.set_title("<unk>")
                            elif word == "<end>":
                                ax.set_title(".")
                            else:
                                ax.set_title(word)

                            img = ax.imshow(first_image)
                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                            if word == "<end>":
                                # Stop after drawing the end-of-sequence step (original drew it too).
                                break

                        plt.tight_layout()
                        target_dir = 'target/attention_images'
                        if not os.path.exists(target_dir):
                            os.makedirs(target_dir)
                        plt.savefig(target_dir + '/attention_train.png')
                        plt.close()