kmeans

parent 461bcb59
......@@ -1369,7 +1369,9 @@ class TranslationDataset(CachedDataset2):
print("kmeans")
print(kmeans)
cur_seq_difficulty = kmeans.labels
else:
curriculum_learning['cur_cluster'] = 0
else:
raise NotImplementedError("This difficulty mode is not implemented.")
# norm it between 0 and 1
if not 'norm' in curriculum_learing:
......@@ -1432,7 +1434,13 @@ class TranslationDataset(CachedDataset2):
self._num_seqs = len(self._seq_order)
return False
self._seq_order = [i for i in range(len(self._data[self._main_data_key])) if (self.seq_difficulty[i] <= competence)]
if curriculum_learning['difficulty'] == 'kmeans_only_source':
self._seq_order = [i for i in range(len(self._data[self._main_data_key])) if
(self.seq_difficulty[i] == curriculum_learning['cur_cluster'])]
curriculum_learning['cur_cluster'] = curriculum_learning['cur_cluster'] + 1) % curriculum_learning['number_of_clusters']
else:
self._seq_order = [i for i in range(len(self._data[self._main_data_key])) if (self.seq_difficulty[i] <= competence)]
self._num_seqs = len(self._seq_order)
# print("num_seqs should be less thatn 6million")
# print(self._num_seqs)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment