Commit 82b14bdc authored by Sparsh Jauhari's avatar Sparsh Jauhari 💬
Browse files
parents 10dc1896 9a81d363
......@@ -12,6 +12,14 @@ to add pretrained word embedding models, write in terminal:
To get the debiased version of the pretrained word2Vec embedding, use this link: <br /> https://drive.google.com/file/d/1_PvT4ZvtZjhq4HPywA8-u06epht9ccOw/view?usp=sharing
<br /> OR <br />
`cd word_embeddings` <br />
gdown https://drive.google.com/uc?id=0B5vZVlu2WoS5ZTBSekpUX0RSNDg
Install the dependencies listed in requirements.txt: `pip install -r requirements.txt`
......
......@@ -11,7 +11,7 @@ def predict_all_models(X_test):
for embedding in ['cv','self_w2v','w2v']:
for sampling in ['random','balanced']:
for mask in ['raw','bio']:
for masking in [True,False]:
pred = predict_one(X_test, embedding=embedding, sampling=sampling, masking=masking)
print("\nembedding :",embedding)
......@@ -20,7 +20,7 @@ def predict_all_models(X_test):
for i,x in enumerate(X_test):
print("sent :",x)
print("pred :",pred[i])
print()
# print()
for sampling in ['random','balanced']:
pred = predict_one(X_test, embedding='elmo', sampling=sampling, masking=True)
......@@ -31,7 +31,7 @@ def predict_all_models(X_test):
for i,x in enumerate(X_test):
print("sent :",x)
print("pred :",pred[i])
print()
# print()
if __name__ == "__main__":
import argparse
......@@ -43,11 +43,12 @@ if __name__ == "__main__":
parser.add_argument("--sampling", help = "choice of sampling ('random', 'balanced')")
parser.add_argument("--embedding", help = "choice of embeddings to be used ('cv': count_vectorize(self-trained), 'w2v': word2vec_embedding(pre-trained), 'self_w2v':w2v(self-trained), 'elmo':elmo(pre-trained))")
parser.add_argument("--pred_all_models", help = 'yes or no')
args = parser.parse_args()
X_test = ['She works at the hospital','He works at the hospital']
if args.pred_all_models == 'yes':
predict_all_models(X_test)
else:
X_test = ['She works at the hospital','He works at the hospital']
pred = predict_one(X_test, embedding=args.embedding, sampling=args.sampling, masking=args.masking)
......@@ -57,4 +58,4 @@ if __name__ == "__main__":
for i,x in enumerate(X_test):
print("sent :",x)
print("pred :",pred[i])
print()
# print()
......@@ -2,4 +2,5 @@ distlib==0.3.1
pymongo==3.11.4
virtualenv==20.4.2
gensim==4.0.1
dnspython
\ No newline at end of file
dnspython
gdown
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment