...
 
Commits (3)
......@@ -22,6 +22,7 @@ import typing
from random import Random
import math
from sklearn.cluster import KMeans
import h5py
class LmDataset(CachedDataset2):
"""
......@@ -1360,15 +1361,27 @@ class TranslationDataset(CachedDataset2):
f1 = f.readlines()
f2 = [float(x) for x in f1]
cur_seq_difficulty = f2
assert self._num_seqs == len(cur_seq_difficulty), "We expect the train_data and the score_file to have the" \
" same length, but {} vs {}".format(self._num_seqs,
len(cur_seq_difficulty))
if 'reverse' in curriculum_learing and curriculum_learing['reverse']:
curriculum_learing['reverse'] = False
else:
curriculum_learing['reverse'] = True
elif mode == 'kmeans_only_source':
kmeans = KMeans(n_clusters=curriculum_learing['number_of_clusters'], random_state=0).fit(self._data[self._main_data_key])
print("kmeans_only_source computes difficulty")
# needs to be an hdf file
source_embed = h5py.File(curriculum_learing['source_embed_file'], 'r')
source_embed = source_embed.get('inputs').value # is now an nd-array
kmeans = KMeans(n_clusters=curriculum_learing['number_of_clusters'], random_state=0).fit(source_embed)
print("kmeans")
print(kmeans)
cur_seq_difficulty = kmeans.labels
# align the right values to the sequences
source_tags = h5py.File(curriculum_learing['source_embed_file'], 'r')
source_tags = source_tags.get('inputs').value # is now an nd-array
source_tags = [int(''.join([s for s in tag if s.isdigit()])) for tag in source_tags]
else:
raise NotImplementedError("This difficulty mode is not implemented.")
# norm it between 0 and 1
......@@ -1438,6 +1451,9 @@ class TranslationDataset(CachedDataset2):
self._seq_order = self._seq_order[:curriculum_learning['seq_order_len']]
self._num_seqs = len(self._seq_order)
# TODO sort them with laplace
if 'n_times_per_slice' in curriculum_learning:
self._seq_order = numpy.tile(self._seq_order, curriculum_learning['n_times_per_slice'])
self._num_seqs = len(self._seq_order)
return True
#
......
......@@ -1154,7 +1154,8 @@ class Engine(EngineBase):
print(self.train_data._num_seqs)
print("num_seqs")
self.train_data.get_seq_difficulty(self.curriculum_learning)
#self.curriculum_learning['old_seq_order'] = self.train_data._seq_order
# if 'n_times_per_slice' in self.curriculum_learning:
# #TODO increase learning rate by 3
self.dataset_batches.clear()
......