Commit a695e7f9 authored by Nishtha Jain
Browse files

raw cut at start pos

parent 5f424393
This diff is collapsed.
......@@ -20,6 +20,8 @@ SEED = 414325
WORD2VEC_PATH = "word_embeddings/GoogleNews-vectors-negative300.bin"
# Boolean -> tag string: True selects 'bio', False selects 'raw'.
# The same two tags appear as the last element of the DATASET_NAMES keys.
MASKED = {
    True: 'bio',
    False: 'raw',
}
......@@ -41,7 +43,7 @@ DATASET_NAMES = {
('embedding','w2v','trial','random',0.2,'raw') : 'word_embeddings/embd__wv_tri_ran_0.2_r',
('embedding','cv','trial','balanced',0.2,'bio') : 'word_embeddings/embd__cv_tri_bal_0.2_b',
('embedding','w2v','trial','balanced',0.2,'bio') : 'word_embeddings/embd__wv_tri_bal_0.2_b',
('embedding','elmo','trial','balanced',0.2,'bio') : 'word_embeddings/embd__ev_tri_bal_0.2_b',
('embedding','elmo','trial','balanced',0.2,'bio') : 'word_embeddings/embd__ev_tri_bal_0.2_b', ## needed???
('model','svm','cv','trial','random',0.2,'bio') : 'models/modl_sv_cv_tri_ran_0.2_b',
('model','svm','cv','trial','random',0.2,'raw') : 'models/modl_sv_cv_tri_ran_0.2_r',
......
......@@ -12,6 +12,10 @@ def get_data_from_mongo(class_group):
data = list(collection.find({'$or':[{'title':title} for title in CLASS_GROUP[class_group]]}))
for i,cursor in enumerate(data):
data[i]['raw_old'] = data[i]['raw']
data[i]['raw'] = data[i]['raw_old'][data[i]['start_pos']:]
print("\t saving file : ",DATASET_NAMES['datasets',class_group])
dump(data, DATASET_NAMES['datasets',class_group]+'.joblib')
return(data)
......@@ -26,10 +30,6 @@ def get_distinct_field_values(field_name):
return ([x[field_name] for x in collection.find({},{field_name:1})])
def load_data(class_group, from_saved=True):
print('processing sampling.load_data ...')
......@@ -39,10 +39,11 @@ def load_data(class_group, from_saved=True):
data = load(DATASET_NAMES['datasets',class_group]+'.joblib')
return data
'''
sampling -> random, balanced(upsample, downsample, weighted)
'''
def data_selection(data, class_group, sampling, test_size, masking=True):
def data_selection(data, class_group, sampling, test_size):
print('processing sampling.data_selection ...')
if sampling == 'random' :
......
......@@ -13,7 +13,7 @@ def main(load_data_from_saved, embedding_train, model_train, predict, evaluate,
print("processing train.main ...")
data = pd.DataFrame(load_data(class_group=class_group, from_saved=load_data_from_saved))
train_set,test_set = data_selection(data, class_group, sampling, test_size, masking)
train_set,test_set = data_selection(data, class_group, sampling, test_size)
# TODO: handle the case where model_train is False but embedding_train is True
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment