Commit 9a81d363 authored by Nishtha Jain's avatar Nishtha Jain
Browse files

predict all added in predict.py

parent 7ff58550
......@@ -11,7 +11,7 @@ def predict_all_models(X_test):
for embedding in ['cv','self_w2v','w2v']:
for sampling in ['random','balanced']:
for mask in ['raw','bio']:
for masking in [True,False]:
pred = predict_one(X_test, embedding=embedding, sampling=sampling, masking=masking)
print("\nembedding :",embedding)
......@@ -20,7 +20,7 @@ def predict_all_models(X_test):
for i,x in enumerate(X_test):
print("sent :",x)
print("pred :",pred[i])
print()
# print()
for sampling in ['random','balanced']:
pred = predict_one(X_test, embedding='elmo', sampling=sampling, masking=True)
......@@ -31,7 +31,7 @@ def predict_all_models(X_test):
for i,x in enumerate(X_test):
print("sent :",x)
print("pred :",pred[i])
print()
# print()
if __name__ == "__main__":
import argparse
......@@ -44,10 +44,11 @@ if __name__ == "__main__":
parser.add_argument("--embedding", help = "choice of embeddings to be used ('cv': count_vectorize(self-trained), 'w2v': word2vec_embedding(pre-trained), 'self_w2v':w2v(self-trained), 'elmo':elmo(pre-trained))")
parser.add_argument("--pred_all_models", help = 'yes or no')
args = parser.parse_args()
X_test = ['She works at the hospital','He works at the hospital']
if args.pred_all_models == 'yes':
predict_all_models(X_test)
else:
X_test = ['She works at the hospital','He works at the hospital']
pred = predict_one(X_test, embedding=args.embedding, sampling=args.sampling, masking=args.masking)
......@@ -57,4 +58,4 @@ if __name__ == "__main__":
for i,x in enumerate(X_test):
print("sent :",x)
print("pred :",pred[i])
print()
# print()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment