Commit 9a81d363 authored by Nishtha Jain's avatar Nishtha Jain
Browse files

predict all added in predict.py

parent 7ff58550
...@@ -11,7 +11,7 @@ def predict_all_models(X_test): ...@@ -11,7 +11,7 @@ def predict_all_models(X_test):
for embedding in ['cv','self_w2v','w2v']: for embedding in ['cv','self_w2v','w2v']:
for sampling in ['random','balanced']: for sampling in ['random','balanced']:
for mask in ['raw','bio']: for masking in [True,False]:
pred = predict_one(X_test, embedding=embedding, sampling=sampling, masking=masking) pred = predict_one(X_test, embedding=embedding, sampling=sampling, masking=masking)
print("\nembedding :",embedding) print("\nembedding :",embedding)
...@@ -20,7 +20,7 @@ def predict_all_models(X_test): ...@@ -20,7 +20,7 @@ def predict_all_models(X_test):
for i,x in enumerate(X_test): for i,x in enumerate(X_test):
print("sent :",x) print("sent :",x)
print("pred :",pred[i]) print("pred :",pred[i])
print() # print()
for sampling in ['random','balanced']: for sampling in ['random','balanced']:
pred = predict_one(X_test, embedding='elmo', sampling=sampling, masking=True) pred = predict_one(X_test, embedding='elmo', sampling=sampling, masking=True)
...@@ -31,7 +31,7 @@ def predict_all_models(X_test): ...@@ -31,7 +31,7 @@ def predict_all_models(X_test):
for i,x in enumerate(X_test): for i,x in enumerate(X_test):
print("sent :",x) print("sent :",x)
print("pred :",pred[i]) print("pred :",pred[i])
print() # print()
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
...@@ -43,11 +43,12 @@ if __name__ == "__main__": ...@@ -43,11 +43,12 @@ if __name__ == "__main__":
parser.add_argument("--sampling", help = "choice of sampling ('random', 'balanced')") parser.add_argument("--sampling", help = "choice of sampling ('random', 'balanced')")
parser.add_argument("--embedding", help = "choice of embeddings to be used ('cv': count_vectorize(self-trained), 'w2v': word2vec_embedding(pre-trained), 'self_w2v':w2v(self-trained), 'elmo':elmo(pre-trained))") parser.add_argument("--embedding", help = "choice of embeddings to be used ('cv': count_vectorize(self-trained), 'w2v': word2vec_embedding(pre-trained), 'self_w2v':w2v(self-trained), 'elmo':elmo(pre-trained))")
parser.add_argument("--pred_all_models", help = 'yes or no') parser.add_argument("--pred_all_models", help = 'yes or no')
args = parser.parse_args()
X_test = ['She works at the hospital','He works at the hospital']
if args.pred_all_models == 'yes': if args.pred_all_models == 'yes':
predict_all_models(X_test) predict_all_models(X_test)
else: else:
X_test = ['She works at the hospital','He works at the hospital']
pred = predict_one(X_test, embedding=args.embedding, sampling=args.sampling, masking=args.masking) pred = predict_one(X_test, embedding=args.embedding, sampling=args.sampling, masking=args.masking)
...@@ -57,4 +58,4 @@ if __name__ == "__main__": ...@@ -57,4 +58,4 @@ if __name__ == "__main__":
for i,x in enumerate(X_test): for i,x in enumerate(X_test):
print("sent :",x) print("sent :",x)
print("pred :",pred[i]) print("pred :",pred[i])
print() # print()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment