Commit 37612bd8 authored by Nishtha Jain's avatar Nishtha Jain
Browse files

argparse added

parent 9077c14f
......@@ -29,4 +29,4 @@ train.py - contains the runnable flow of the project
run train.py
\ No newline at end of file
python train.py --load_data_from_saved False --embedding_train True --model_train True --predict True --evaluate True --class_group medical --sampling random --embedding cv --model svm --test_size 0.2 --masking True
\ No newline at end of file
......@@ -60,19 +60,29 @@ DATASET_NAMES = {
('embedding','cv','medical','random',0.2,'raw') : 'word_embeddings/embd__cv_med_ran_0.2_r',
('embedding','w2v','medical','random',0.2,'bio') : 'word_embeddings/embd__wv_med_ran_0.2_b',
('embedding','w2v','medical','random',0.2,'raw') : 'word_embeddings/embd__wv_med_ran_0.2_r',
('embedding','self_w2v','medical','random',0.2,'bio') : 'word_embeddings/embd__tv_med_ran_0.2_b',
('embedding','self_w2v','medical','random',0.2,'raw') : 'word_embeddings/embd__tv_med_ran_0.2_r',
('embedding','cv','medical','balanced',0.2,'bio') : 'word_embeddings/embd__cv_med_bal_0.2_b',
('embedding','cv','medical','balanced',0.2,'raw') : 'word_embeddings/embd__cv_med_bal_0.2_r',
('embedding','w2v','medical','balanced',0.2,'bio') : 'word_embeddings/embd__wv_med_bal_0.2_b',
('embedding','w2v','medical','balanced',0.2,'raw') : 'word_embeddings/embd__wv_med_bal_0.2_r',
('embedding','self_w2v','medical','balanced',0.2,'bio') : 'word_embeddings/embd__tv_med_bal_0.2_b',
('embedding','self_w2v','medical','balanced',0.2,'raw') : 'word_embeddings/embd__tv_med_bal_0.2_r',
('model','svm','cv','medical','random',0.2,'bio') : 'models/modl_sv_cv_med_ran_0.2_b',
('model','svm','cv','medical','random',0.2,'raw') : 'models/modl_sv_cv_med_ran_0.2_r',
('model','svm','w2v','medical','random',0.2,'bio') : 'models/modl_sv_wv_med_ran_0.2_b',
('model','svm','w2v','medical','random',0.2,'raw') : 'models/modl_sv_wv_med_ran_0.2_r',
('model','svm','self_w2v','medical','random',0.2,'bio') : 'models/modl_sv_tv_med_ran_0.2_b',
('model','svm','self_w2v','medical','random',0.2,'raw') : 'models/modl_sv_tv_med_ran_0.2_r',
('model','svm','cv','medical','balanced',0.2,'bio') : 'models/modl_sv_cv_med_bal_0.2_b',
('model','svm','cv','medical','balanced',0.2,'raw') : 'models/modl_sv_cv_med_bal_0.2_r',
('model','svm','w2v','medical','balanced',0.2,'bio') : 'models/modl_sv_wv_med_bal_0.2_b',
('model','svm','w2v','medical','balanced',0.2,'raw') : 'models/modl_sv_wv_med_bal_0.2_r'
('model','svm','w2v','medical','balanced',0.2,'raw') : 'models/modl_sv_wv_med_bal_0.2_r',
('model','svm','self_w2v','medical','balanced',0.2,'bio') : 'models/modl_sv_tv_med_bal_0.2_b',
('model','svm','self_w2v','medical','balanced',0.2,'raw') : 'models/modl_sv_tv_med_bal_0.2_r'
}
......@@ -80,12 +90,19 @@ PREDICTED_DATASET = {
# Naming convention : pred_[sv|rf|nn]_[cv|wv|tv]_[tri|med]_[ran|bal]_[test_spit]_[b|r]
('svm','cv','trial','balanced',0.2,'raw') : 'predicted_datasets/pred_sv_cv_tri_bal_0.2_r',
('svm','cv','trial','balanced',0.2,'bio') : 'predicted_datasets/pred_sv_cv_tri_bal_0.2_b',
('svm','w2v','trial','balanced',0.2,'raw') : 'predicted_datasets/pred_sv_wv_tri_bal_0.2_r',
('svm','w2v','trial','balanced',0.2,'bio') : 'predicted_datasets/pred_sv_wv_tri_bal_0.2_b',
('svm','self_w2v','trial','balanced',0.2,'raw') : 'predicted_datasets/pred_sv_tv_tri_bal_0.2_r',
('svm','self_w2v','trial','balanced',0.2,'bio') : 'predicted_datasets/pred_sv_tv_tri_bal_0.2_b',
('svm','cv','medical','balanced',0.2,'raw') : 'predicted_datasets/pred_sv_cv_med_bal_0.2_r',
('svm','cv','medical','balanced',0.2,'bio') : 'predicted_datasets/pred_sv_cv_med_bal_0.2_b',
('svm','w2v','medical','balanced',0.2,'raw') : 'predicted_datasets/pred_sv_wv_med_bal_0.2_r',
('svm','w2v','medical','balanced',0.2,'bio') : 'predicted_datasets/pred_sv_wv_med_bal_0.2_b',
('svm','self_w2v','medical','balanced',0.2,'raw') : 'predicted_datasets/pred_sv_tv_med_bal_0.2_r',
('svm','self_w2v','medical','balanced',0.2,'bio') : 'predicted_datasets/pred_sv_tv_med_bal_0.2_b'
}
......@@ -94,7 +111,18 @@ PLOT_NAMES = {
# Naming convention : plot_[tgp|aod]_[M|F]_[sv|rf|nn]_[cv|wv|tv]_[tri|med]_[ran|bal]_[test_spit]_[b|r]
('tgp','svm','cv','trial','balanced',0.2,'raw') : 'plots_and_graphs/plot_tgp_sv_cv_tri_bal_0.2_r',
('tgp','svm','cv','trial','balanced',0.2,'bio') : 'plots_and_graphs/plot_tgp_sv_cv_tri_bal_0.2_b',
('tgp','svm','w2v','trial','balanced',0.2,'raw') : 'plots_and_graphs/plot_tgp_sv_wv_tri_bal_0.2_r',
('tgp','svm','w2v','trial','balanced',0.2,'bio') : 'plots_and_graphs/plot_tgp_sv_wv_tri_bal_0.2_b',
('tgp','svm','self_w2v','trial','balanced',0.2,'raw') : 'plots_and_graphs/plot_tgp_sv_tv_tri_bal_0.2_r',
('tgp','svm','self_w2v','trial','balanced',0.2,'bio') : 'plots_and_graphs/plot_tgp_sv_tv_tri_bal_0.2_b',
('aod','svm','cv','trial','balanced',0.2,'raw') : 'plots_and_graphs/plot_aod_sv_cv_tri_bal_0.2_r',
('aod','svm','cv','trial','balanced',0.2,'bio') : 'plots_and_graphs/plot_aod_sv_cv_tri_bal_0.2_b',
('aod','svm','w2v','trial','balanced',0.2,'raw') : 'plots_and_graphs/plot_aod_sv_wv_tri_bal_0.2_r',
('aod','svm','w2v','trial','balanced',0.2,'bio') : 'plots_and_graphs/plot_aod_sv_wv_tri_bal_0.2_b',
('aod','svm','self_w2v','trial','balanced',0.2,'raw') : 'plots_and_graphs/plot_aod_sv_tv_tri_bal_0.2_r',
('aod','svm','self_w2v','trial','balanced',0.2,'bio') : 'plots_and_graphs/plot_aod_sv_tv_tri_bal_0.2_b'
}
\ No newline at end of file
......@@ -2,5 +2,4 @@ distlib==0.3.1
pymongo==3.11.4
virtualenv==20.4.2
gensim==4.0.1
pymongo
dnspython
\ No newline at end of file
......@@ -58,22 +58,45 @@ class_group -> choice of domain of occupations ('trial','medical')
sampling -> choice of sampling ('random', 'balanced')
embedding -> choice of embeddings to be used ('cv': count_vectorize(self-trained), 'w2v': word2vec_embedding(pre-trained), 'self_w2v':w2v(self-trained))
model -> choice of models to be trained ('svm', 'rf', 'nn')
test_size -> proportion of data to be used for tesing (IntegerValue)
test_size -> proportion of data to be used for tesing
masking -> True for 'bio' data and False for 'raw' data
'''
if __name__ == "__main__":
import time
start_time = time.time()
import argparse
# Initialize parser
parser = argparse.ArgumentParser()
# Adding optional argument
parser.add_argument("--load_data_from_saved", required=True, help = "True if saved data to be used and False if new data to be taken")
parser.add_argument("--embedding_train", default=True, help = "True to train new embedding and False to use the saved one")
parser.add_argument("--model_train", default=True, help = "True to train new model and False to use the saved one")
parser.add_argument("--predict", default=True, help = "True to perform predictions on the test set and False otherwise")
parser.add_argument("--evaluate", default=True, help = "True to perform bias evaluations on the test set and False otherwise")
parser.add_argument("--class_group", default='medical', required=True, help = "choice of domain of occupations ('trial','medical')")
parser.add_argument("--sampling", required=True, help = "choice of sampling ('random', 'balanced')")
parser.add_argument("--embedding", required=True, help = "choice of embeddings to be used ('cv': count_vectorize(self-trained), 'w2v': word2vec_embedding(pre-trained), 'self_w2v':w2v(self-trained))")
parser.add_argument("--model", default= 'svm', required=True, help = "choice of models to be trained ('svm', 'rf', 'nn')")
parser.add_argument("--test_size", default = 0.2, required=True, help = "proportion of data to be used for tesing")
parser.add_argument("--masking", required=True, help = "True for 'bio' data and False for 'raw' data")
# Read arguments from command line
args = parser.parse_args()
if args.load_data_from_saved:
print("Displaying load_data_from_saved as: % s" % args.load_data_from_saved)
main(load_data_from_saved = True,
embedding_train = False,
model_train = False,
predict = False,
evaluate = False,
class_group = 'trial',
sampling = 'balanced',
embedding = 'self_w2v',
model = 'svm',
test_size = 0.2,
masking = True)
main(load_data_from_saved = args.load_data_from_saved,
embedding_train = args.embedding_train,
model_train = args.model_train,
predict = args.predict,
evaluate = args.evaluate,
class_group = args.class_group,
sampling = args.sampling,
embedding = args.embedding,
model = args.model,
test_size = float(args.test_size),
masking = args.masking)
print("\n--- %s seconds ---" % (time.time() - start_time))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment