Skip to content
Snippets Groups Projects
Commit c1c4c3ff authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Implemented bleu

parent 665375a8
No related branches found
No related tags found
1 merge request!23Added Unroll-related features and layers
......@@ -5,6 +5,8 @@ import time
import os
import shutil
import pickle
import math
import sys
from mxnet import gluon, autograd, nd
class CrossEntropyLoss(gluon.loss.Loss):
......@@ -32,6 +34,115 @@ class LogCoshLoss(gluon.loss.Loss):
loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)
@mx.metric.register
class BLEU(mx.metric.EvalMetric):
N = 4
def __init__(self, exclude=None, name='bleu', output_names=None, label_names=None):
super(BLEU, self).__init__(name=name, output_names=output_names, label_names=label_names)
self._exclude = exclude or set()
self._match_counts = [0 for _ in range(self.N)]
self._counts = [0 for _ in range(self.N)]
self._size_ref = 0
self._size_hyp = 0
def update(self, labels, preds):
labels, preds = mx.metric.check_label_shapes(labels, preds, True)
new_labels = self._convert(labels)
new_preds = self._convert(preds)
for label, pred in zip(new_labels, new_preds):
reference = [word for word in label if word not in self._exclude]
hypothesis = [word for word in pred if word not in self._exclude]
self._size_ref += len(reference)
self._size_hyp += len(hypothesis)
for n in range(self.N):
reference_ngrams = self._get_ngrams(reference, n + 1)
hypothesis_ngrams = self._get_ngrams(hypothesis, n + 1)
match_count = 0
for ngram in hypothesis_ngrams:
if ngram in reference_ngrams:
reference_ngrams.remove(ngram)
match_count += 1
self._match_counts[n] += match_count
self._counts[n] += len(hypothesis_ngrams)
def get(self):
precisions = [sys.float_info.min for n in range(self.N)]
i = 1
for n in range(self.N):
match_counts = self._match_counts[n]
counts = self._counts[n]
if counts != 0:
if match_counts == 0:
i *= 2
match_counts = 1 / i
precisions[n] = match_counts / counts
bleu = self._get_brevity_penalty() * math.exp(sum(map(math.log, precisions)) / self.N)
return (self.name, bleu)
def calculate(self):
precisions = [sys.float_info.min for n in range(self.N)]
i = 1
for n in range(self.N):
match_counts = self._match_counts[n]
counts = self._counts[n]
if counts != 0:
if match_counts == 0:
i *= 2
match_counts = 1 / i
precisions[n] = match_counts / counts
return self._get_brevity_penalty() * math.exp(sum(map(math.log, precisions)) / self.N)
def _get_brevity_penalty(self):
if self._size_hyp >= self._size_ref:
return 1
else:
return math.exp(1 - (self._size_ref / self._size_hyp))
@staticmethod
def _get_ngrams(sentence, n):
ngrams = []
if len(sentence) >= n:
for i in range(len(sentence) - n + 1):
ngrams.append(sentence[i:i+n])
return ngrams
@staticmethod
def _convert(nd_list):
if len(nd_list) == 0:
return []
new_list = [[] for _ in range(nd_list[0].shape[0])]
for element in nd_list:
for i in range(element.shape[0]):
new_list[i].append(element[i].asscalar())
return new_list
class ${tc.fileNameWithoutEnding}:
def applyBeamSearch(input, length, width, maxLength, currProb, netIndex, bestOutput):
......@@ -230,35 +341,6 @@ class ${tc.fileNameWithoutEnding}:
else:
predictions.append(output_name)
#Compute BLEU and NIST Score if data folder contains a dictionary -> NLP dataset
if(os.path.isfile('${tc.dataPath}/dict.pkl')):
with open('${tc.dataPath}/dict.pkl', 'rb') as f:
dict = pickle.load(f)
import nltk.translate.bleu_score
import nltk.translate.nist_score
prediction = []
for index in range(batch_size):
sentence = ''
for entry in predictions:
sentence += dict[int(entry[index].asscalar())] + ' '
prediction.append(sentence)
for index in range(batch_size):
sentence = ''
for batchEntry in batch.label:
sentence += dict[int(batchEntry[index].asscalar())] + ' '
print("############################")
print("label: ", sentence)
print("prediction: ", prediction[index])
BLEUscore = nltk.translate.bleu_score.sentence_bleu([sentence], prediction[index])
NISTscore = nltk.translate.nist_score.sentence_nist([sentence], prediction[index])
print("BLEU: ", BLEUscore)
print("NIST: ", NISTscore)
print("############################")
metric.update(preds=predictions, labels=labels)
train_metric_score = metric.get()[1]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment