kmeans difficulty

parent 7566c3f5
......@@ -21,6 +21,7 @@ import re
import typing
from random import Random
import math
from sklearn.cluster import KMeans
class LmDataset(CachedDataset2):
"""
......@@ -1363,7 +1364,11 @@ class TranslationDataset(CachedDataset2):
curriculum_learing['reverse'] = False
else:
curriculum_learing['reverse'] = True
elif mode == 'kmeans_only_source':
kmeans = KMeans(n_clusters=curriculum_learing['number_of_clusters'], random_state=0).fit(self._data[self._main_data_key])
print("kmeans")
print(kmeans)
cur_seq_difficulty = kmeans.labels
else:
raise NotImplementedError("This difficulty mode is not implemented.")
# norm it between 0 and 1
......@@ -1385,6 +1390,8 @@ class TranslationDataset(CachedDataset2):
cur_seq_difficulty = idx_sorted / max_seq
print('cur_seq_difficulty_at_the_end')
print(cur_seq_difficulty)
elif curriculum_learing['norm'] == 'no_norm':
pass
else:
raise NotImplementedError("This norm mode is not implemented.")
# TODO check if reverse works fine
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment