Commit 1994fb64 authored by Nishtha Jain
Browse files

trial elmo

parent a319c0c8
......@@ -25,7 +25,7 @@ MASKED = { True:'bio',
}
DATASET_NAMES = {
# Naming convention : [dset|embd|modl]_[sv|rf|nn]_[cv|wv|tv]_[tri|med]_[ran|bal]_[test_split]_[b|r]
# Naming convention : [dset|embd|modl]_[sv|rf|nn]_[cv|wv|tv|ev]_[tri|med]_[ran|bal]_[test_split]_[b|r]
# trial domain
('datasets','trial') : 'datasets/dset___tri___',
......@@ -39,6 +39,7 @@ DATASET_NAMES = {
('embedding','cv','trial','balanced',0.2,'bio') : 'word_embeddings/embd__cv_tri_bal_0.2_b',
('embedding','w2v','trial','balanced',0.2,'bio') : 'word_embeddings/embd__wv_tri_bal_0.2_b',
('embedding','self_w2v','trial','balanced',0.2,'bio') : 'word_embeddings/embd__tv_tri_bal_0.2_b',
('embedding','elmo','trial','balanced',0.2,'bio') : 'word_embeddings/embd__ev_tri_bal_0.2_b',
('model','svm','cv','trial','random',0.2,'bio') : 'models/modl_sv_cv_tri_ran_0.2_b',
('model','svm','cv','trial','random',0.2,'raw') : 'models/modl_sv_cv_tri_ran_0.2_r',
......@@ -47,6 +48,7 @@ DATASET_NAMES = {
('model','svm','cv','trial','balanced',0.2,'bio') : 'models/modl_sv_cv_tri_bal_0.2_b',
('model','svm','w2v','trial','balanced',0.2,'bio') : 'models/modl_sv_wv_tri_bal_0.2_b',
('model','svm','self_w2v','trial','balanced',0.2,'bio') : 'models/modl_sv_tv_tri_bal_0.2_b',
('model','svm','elmo','trial','balanced',0.2,'bio') : 'models/modl_sv_ev_tri_bal_0.2_b',
......@@ -62,6 +64,8 @@ DATASET_NAMES = {
('embedding','w2v','medical','random',0.2,'raw') : 'word_embeddings/embd__wv_med_ran_0.2_r',
('embedding','self_w2v','medical','random',0.2,'bio') : 'word_embeddings/embd__tv_med_ran_0.2_b',
('embedding','self_w2v','medical','random',0.2,'raw') : 'word_embeddings/embd__tv_med_ran_0.2_r',
('embedding','elmo','medical','random',0.2,'bio') : 'word_embeddings/embd__ev_med_ran_0.2_b',
('embedding','elmo','medical','random',0.2,'raw') : 'word_embeddings/embd__ev_med_ran_0.2_r',
('embedding','cv','medical','balanced',0.2,'bio') : 'word_embeddings/embd__cv_med_bal_0.2_b',
('embedding','cv','medical','balanced',0.2,'raw') : 'word_embeddings/embd__cv_med_bal_0.2_r',
......@@ -69,6 +73,8 @@ DATASET_NAMES = {
('embedding','w2v','medical','balanced',0.2,'raw') : 'word_embeddings/embd__wv_med_bal_0.2_r',
('embedding','self_w2v','medical','balanced',0.2,'bio') : 'word_embeddings/embd__tv_med_bal_0.2_b',
('embedding','self_w2v','medical','balanced',0.2,'raw') : 'word_embeddings/embd__tv_med_bal_0.2_r',
('embedding','elmo','medical','balanced',0.2,'bio') : 'word_embeddings/embd__ev_med_bal_0.2_b',
('embedding','elmo','medical','balanced',0.2,'raw') : 'word_embeddings/embd__ev_med_bal_0.2_r',
('model','svm','cv','medical','random',0.2,'bio') : 'models/modl_sv_cv_med_ran_0.2_b',
('model','svm','cv','medical','random',0.2,'raw') : 'models/modl_sv_cv_med_ran_0.2_r',
......@@ -76,13 +82,17 @@ DATASET_NAMES = {
('model','svm','w2v','medical','random',0.2,'raw') : 'models/modl_sv_wv_med_ran_0.2_r',
('model','svm','self_w2v','medical','random',0.2,'bio') : 'models/modl_sv_tv_med_ran_0.2_b',
('model','svm','self_w2v','medical','random',0.2,'raw') : 'models/modl_sv_tv_med_ran_0.2_r',
('model','svm','elmo','medical','random',0.2,'bio') : 'models/modl_sv_ev_med_ran_0.2_b',
('model','svm','elmo','medical','random',0.2,'raw') : 'models/modl_sv_ev_med_ran_0.2_r',
('model','svm','cv','medical','balanced',0.2,'bio') : 'models/modl_sv_cv_med_bal_0.2_b',
('model','svm','cv','medical','balanced',0.2,'raw') : 'models/modl_sv_cv_med_bal_0.2_r',
('model','svm','w2v','medical','balanced',0.2,'bio') : 'models/modl_sv_wv_med_bal_0.2_b',
('model','svm','w2v','medical','balanced',0.2,'raw') : 'models/modl_sv_wv_med_bal_0.2_r',
('model','svm','self_w2v','medical','balanced',0.2,'bio') : 'models/modl_sv_tv_med_bal_0.2_b',
('model','svm','self_w2v','medical','balanced',0.2,'raw') : 'models/modl_sv_tv_med_bal_0.2_r'
('model','svm','self_w2v','medical','balanced',0.2,'raw') : 'models/modl_sv_tv_med_bal_0.2_r',
('model','svm','elmo','medical','balanced',0.2,'bio') : 'models/modl_sv_ev_med_bal_0.2_b',
('model','svm','elmo','medical','balanced',0.2,'raw') : 'models/modl_sv_ev_med_bal_0.2_r'
}
......@@ -96,6 +106,8 @@ PREDICTED_DATASET = {
('svm','w2v','trial','balanced',0.2,'bio') : 'predicted_datasets/pred_sv_wv_tri_bal_0.2_b',
('svm','self_w2v','trial','balanced',0.2,'raw') : 'predicted_datasets/pred_sv_tv_tri_bal_0.2_r',
('svm','self_w2v','trial','balanced',0.2,'bio') : 'predicted_datasets/pred_sv_tv_tri_bal_0.2_b',
('svm','elmo','trial','balanced',0.2,'raw') : 'predicted_datasets/pred_sv_ev_tri_bal_0.2_r',
('svm','elmo','trial','balanced',0.2,'bio') : 'predicted_datasets/pred_sv_ev_tri_bal_0.2_b',
('svm','cv','trial','random',0.2,'raw') : 'predicted_datasets/pred_sv_cv_tri_ran_0.2_r',
('svm','cv','trial','random',0.2,'bio') : 'predicted_datasets/pred_sv_cv_tri_ran_0.2_b',
......@@ -103,8 +115,8 @@ PREDICTED_DATASET = {
('svm','w2v','trial','random',0.2,'bio') : 'predicted_datasets/pred_sv_wv_tri_ran_0.2_b',
('svm','self_w2v','trial','random',0.2,'raw') : 'predicted_datasets/pred_sv_tv_tri_ran_0.2_r',
('svm','self_w2v','trial','random',0.2,'bio') : 'predicted_datasets/pred_sv_tv_tri_ran_0.2_b',
('svm','elmo','trial','random',0.2,'raw') : 'predicted_datasets/pred_sv_ev_tri_ran_0.2_r',
('svm','elmo','trial','random',0.2,'bio') : 'predicted_datasets/pred_sv_ev_tri_ran_0.2_b',
('svm','cv','medical','random',0.2,'raw') : 'predicted_datasets/pred_sv_cv_med_ran_0.2_r',
......@@ -113,13 +125,17 @@ PREDICTED_DATASET = {
('svm','w2v','medical','random',0.2,'bio') : 'predicted_datasets/pred_sv_wv_med_ran_0.2_b',
('svm','self_w2v','medical','random',0.2,'raw') : 'predicted_datasets/pred_sv_tv_med_ran_0.2_r',
('svm','self_w2v','medical','random',0.2,'bio') : 'predicted_datasets/pred_sv_tv_med_ran_0.2_b',
('svm','elmo','medical','random',0.2,'raw') : 'predicted_datasets/pred_sv_ev_med_ran_0.2_r',
('svm','elmo','medical','random',0.2,'bio') : 'predicted_datasets/pred_sv_ev_med_ran_0.2_b',
('svm','cv','medical','balanced',0.2,'raw') : 'predicted_datasets/pred_sv_cv_med_bal_0.2_r',
('svm','cv','medical','balanced',0.2,'bio') : 'predicted_datasets/pred_sv_cv_med_bal_0.2_b',
('svm','w2v','medical','balanced',0.2,'raw') : 'predicted_datasets/pred_sv_wv_med_bal_0.2_r',
('svm','w2v','medical','balanced',0.2,'bio') : 'predicted_datasets/pred_sv_wv_med_bal_0.2_b',
('svm','self_w2v','medical','balanced',0.2,'raw') : 'predicted_datasets/pred_sv_tv_med_bal_0.2_r',
('svm','self_w2v','medical','balanced',0.2,'bio') : 'predicted_datasets/pred_sv_tv_med_bal_0.2_b'
('svm','self_w2v','medical','balanced',0.2,'bio') : 'predicted_datasets/pred_sv_tv_med_bal_0.2_b',
('svm','elmo','medical','balanced',0.2,'raw') : 'predicted_datasets/pred_sv_ev_med_bal_0.2_r',
('svm','elmo','medical','balanced',0.2,'bio') : 'predicted_datasets/pred_sv_ev_med_bal_0.2_b'
}
......@@ -130,6 +146,8 @@ PLOT_NAMES = {
('tgp','svm','self_w2v','trial','balanced',0.2,'bio') : 'plots_and_graphs/plot_tgp_sv_tv_tri_bal_0.2_b',
('aod','svm','self_w2v','trial','balanced',0.2,'bio') : 'plots_and_graphs/plot_aod_sv_tv_tri_bal_0.2_b',
('tgp','svm','elmo','trial','balanced',0.2,'bio') : 'plots_and_graphs/plot_tgp_sv_ev_tri_bal_0.2_b',
('aod','svm','elmo','trial','balanced',0.2,'bio') : 'plots_and_graphs/plot_aod_sv_ev_tri_bal_0.2_b',
('tgp','svm','cv','medical','random',0.2,'raw') : 'plots_and_graphs/plot_tgp_sv_cv_med_ran_0.2_r',
('tgp','svm','cv','medical','random',0.2,'bio') : 'plots_and_graphs/plot_tgp_sv_cv_med_ran_0.2_b',
......@@ -137,6 +155,8 @@ PLOT_NAMES = {
('tgp','svm','w2v','medical','random',0.2,'bio') : 'plots_and_graphs/plot_tgp_sv_wv_med_ran_0.2_b',
('tgp','svm','self_w2v','medical','random',0.2,'raw') : 'plots_and_graphs/plot_tgp_sv_tv_med_ran_0.2_r',
('tgp','svm','self_w2v','medical','random',0.2,'bio') : 'plots_and_graphs/plot_tgp_sv_tv_med_ran_0.2_b',
('tgp','svm','elmo','medical','random',0.2,'raw') : 'plots_and_graphs/plot_tgp_sv_ev_med_ran_0.2_r',
('tgp','svm','elmo','medical','random',0.2,'bio') : 'plots_and_graphs/plot_tgp_sv_ev_med_ran_0.2_b',
('tgp','svm','cv','medical','balanced',0.2,'raw') : 'plots_and_graphs/plot_tgp_sv_cv_med_bal_0.2_r',
('tgp','svm','cv','medical','balanced',0.2,'bio') : 'plots_and_graphs/plot_tgp_sv_cv_med_bal_0.2_b',
......@@ -144,6 +164,8 @@ PLOT_NAMES = {
('tgp','svm','w2v','medical','balanced',0.2,'bio') : 'plots_and_graphs/plot_tgp_sv_wv_med_bal_0.2_b',
('tgp','svm','self_w2v','medical','balanced',0.2,'raw') : 'plots_and_graphs/plot_tgp_sv_tv_med_bal_0.2_r',
('tgp','svm','self_w2v','medical','balanced',0.2,'bio') : 'plots_and_graphs/plot_tgp_sv_tv_med_bal_0.2_b',
('tgp','svm','elmo','medical','balanced',0.2,'raw') : 'plots_and_graphs/plot_tgp_sv_ev_med_bal_0.2_r',
('tgp','svm','elmo','medical','balanced',0.2,'bio') : 'plots_and_graphs/plot_tgp_sv_ev_med_bal_0.2_b',
('aod','svm','cv','medical','random',0.2,'raw') : 'plots_and_graphs/plot_aod_sv_cv_med_ran_0.2_r',
......@@ -152,12 +174,16 @@ PLOT_NAMES = {
('aod','svm','w2v','medical','random',0.2,'bio') : 'plots_and_graphs/plot_aod_sv_wv_med_ran_0.2_b',
('aod','svm','self_w2v','medical','random',0.2,'raw') : 'plots_and_graphs/plot_aod_sv_tv_med_ran_0.2_r',
('aod','svm','self_w2v','medical','random',0.2,'bio') : 'plots_and_graphs/plot_aod_sv_tv_med_ran_0.2_b',
('aod','svm','elmo','medical','random',0.2,'raw') : 'plots_and_graphs/plot_aod_sv_ev_med_ran_0.2_r',
('aod','svm','elmo','medical','random',0.2,'bio') : 'plots_and_graphs/plot_aod_sv_ev_med_ran_0.2_b',
('aod','svm','cv','medical','balanced',0.2,'raw') : 'plots_and_graphs/plot_aod_sv_cv_med_bal_0.2_r',
('aod','svm','cv','medical','balanced',0.2,'bio') : 'plots_and_graphs/plot_aod_sv_cv_med_bal_0.2_b',
('aod','svm','w2v','medical','balanced',0.2,'raw') : 'plots_and_graphs/plot_aod_sv_wv_med_bal_0.2_r',
('aod','svm','w2v','medical','balanced',0.2,'bio') : 'plots_and_graphs/plot_aod_sv_wv_med_bal_0.2_b',
('aod','svm','self_w2v','medical','balanced',0.2,'raw') : 'plots_and_graphs/plot_aod_sv_tv_med_bal_0.2_r',
('aod','svm','self_w2v','medical','balanced',0.2,'bio') : 'plots_and_graphs/plot_aod_sv_tv_med_bal_0.2_b'
('aod','svm','self_w2v','medical','balanced',0.2,'bio') : 'plots_and_graphs/plot_aod_sv_tv_med_bal_0.2_b',
('aod','svm','elmo','medical','balanced',0.2,'raw') : 'plots_and_graphs/plot_aod_sv_ev_med_bal_0.2_r',
('aod','svm','elmo','medical','balanced',0.2,'bio') : 'plots_and_graphs/plot_aod_sv_ev_med_bal_0.2_b'
}
......@@ -30,6 +30,8 @@ def embedding_fit_transform(x_list, embedding, class_group, sampling, test_size,
trained_embedding, X = pretrained_word2vec_fit_transform(x_list)
elif embedding == 'self_w2v':
trained_embedding, X = selftrained_word2vec_fit_transform(x_list)
elif embedding == 'elmo':
trained_embedding, X = elmo_fit_transform(x_list)
dump(trained_embedding,DATASET_NAMES['embedding',embedding,class_group,sampling,test_size,MASKED[masking]]+'.joblib')
print("\t saving file :",DATASET_NAMES['embedding',embedding,class_group,sampling,test_size,MASKED[masking]])
......@@ -46,6 +48,8 @@ def embedding_transform(x_list, embedding, class_group, sampling, test_size, mas
X = count_vectorize_transform(trained_embedding,x_list)
elif embedding == 'w2v'or embedding == 'self_w2v':
X = word2vec_transform(trained_embedding,x_list)
elif embedding == 'elmo':
X = elmo_transform(trained_embedding, x_list)
return(X)
......@@ -77,15 +81,8 @@ def count_vectorize_fit_transform(x_list):
def count_vectorize_transform(vectorizer_binary, x_list):
    """Vectorize ``x_list`` with an already-fitted vectorizer.

    Args:
        vectorizer_binary: Fitted object exposing ``transform(x_list)``
            (presumably the vectorizer produced by
            ``count_vectorize_fit_transform`` -- confirm against caller).
        x_list: The documents to vectorize.

    Returns:
        Whatever ``vectorizer_binary.transform(x_list)`` returns.
    """
    # No fitting happens here: the vectorizer is only applied, never updated.
    return vectorizer_binary.transform(x_list)
def elmo_transform(x_list):
    """Embed every sentence of ``x_list`` with a freshly loaded ELMo module.

    NOTE(review): another module-level ``elmo_transform(elmo, x_list)`` is
    defined further down. If both definitions live in the same module, the
    later one rebinds this name and this one-argument version becomes
    unreachable -- confirm which one is meant to survive.

    Args:
        x_list: Iterable of raw sentences; each is passed through
            ``preProcessAndTokenize`` and re-joined with single spaces.

    Returns:
        list: One flattened vector per input sentence, in input order.
    """
    # The module is loaded on every call; callers embedding several batches
    # pay that cost each time.
    elmo = hub.load("https://tfhub.dev/google/elmo/3")
    # Loop-invariant lookup hoisted out of the per-sentence loop.
    embed = elmo.signatures["default"]

    embeddings_list = []
    for sent in x_list:
        text = ' '.join(preProcessAndTokenize(sent))
        embeddings = embed(tf.constant([text]))
        # Average over axis 1 (presumably the token axis of 'word_emb' --
        # TODO confirm against the TF-Hub module docs), then flatten.
        embeddings_list.append(np.mean(embeddings['word_emb'], 1).flatten())
    return embeddings_list
class MeanEmbeddingVectorizer(object):
def __init__(self, word2vec):
......@@ -137,6 +134,21 @@ def selftrained_word2vec_fit_transform(x_list):
return model , X
def elmo_transform(elmo, x_list):
    """Embed every sentence of ``x_list`` with an already-loaded ELMo module.

    Args:
        elmo: Loaded module exposing ``signatures["default"]`` (e.g. the
            object returned by ``hub.load``).
        x_list: Iterable of raw sentences; each is passed through
            ``preProcessAndTokenize`` and re-joined with single spaces.

    Returns:
        list: One flattened vector per input sentence, in input order.
    """
    # Loop-invariant lookup hoisted out of the per-sentence loop.
    embed = elmo.signatures["default"]

    X = []
    for sent in x_list:
        text = ' '.join(preProcessAndTokenize(sent))
        # Sentences are embedded one at a time (batch of size 1).
        embeddings = embed(tf.constant([text]))
        # Average over axis 1 (presumably the token axis of 'word_emb' --
        # TODO confirm against the TF-Hub module docs), then flatten.
        # NOTE(review): 'word_emb' is read rather than the module's other
        # outputs -- confirm this is the intended representation.
        X.append(np.mean(embeddings['word_emb'], 1).flatten())
    return X
def elmo_fit_transform(x_list, module_url="https://tfhub.dev/google/elmo/3"):
    """Load an ELMo module and embed ``x_list`` with it.

    Nothing is trained here; "fit" only mirrors the naming of the other
    ``*_fit_transform`` helpers, so the loaded module plays the role of the
    "trained embedding".

    Args:
        x_list: Iterable of raw sentences to embed.
        module_url: Location handed to ``hub.load``. Defaults to the ELMo v3
            module the original code hard-coded, so existing callers are
            unaffected.

    Returns:
        tuple: ``(elmo, X)`` -- the loaded module and the embeddings that
        ``elmo_transform(elmo, x_list)`` produced.
    """
    elmo = hub.load(module_url)
    X = elmo_transform(elmo, x_list)
    return elmo, X
......
......@@ -101,7 +101,7 @@ if __name__ == "__main__":
parser.add_argument("--class_group", default='medical', required=True, help = "choice of domain of occupations ('trial','medical')")
parser.add_argument("--sampling", required=True, help = "choice of sampling ('random', 'balanced')")
parser.add_argument("--embedding", required=True, help = "choice of embeddings to be used ('cv': count_vectorize(self-trained), 'w2v': word2vec_embedding(pre-trained), 'self_w2v':w2v(self-trained))")
parser.add_argument("--embedding", required=True, help = "choice of embeddings to be used ('cv': count_vectorize(self-trained), 'w2v': word2vec_embedding(pre-trained), 'self_w2v':w2v(self-trained), 'elmo':elmo(pre-trained))")
parser.add_argument("--model", default= 'svm', required=True, help = "choice of models to be trained ('svm', 'rf', 'nn')")
parser.add_argument("--test_size", default = 0.2, required=True, help = "proportion of data to be used for tesing")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment