Commit 144bc98d authored by Nishtha Jain's avatar Nishtha Jain
Browse files

predict.py updated

parent 3be43ca9
......@@ -8,53 +8,53 @@ def predict_one(X_test, embedding, class_group='medical', sampling='balanced', t
return prediction
def predict_all_models(X_test):
for embedding in ['cv','self_w2v','w2v']:
for sampling in ['random','balanced']:
for mask in ['raw','bio']:
pred = predict_one(X_test, embedding=embedding, sampling=sampling, masking=masking)
print("\nembedding :",embedding)
print("sampling :",sampling)
print("masking :",masking)
for i,x in enumerate(X_test):
print("sent :",x)
print("pred :",pred[i])
print()
for embedding in ['cv','self_w2v','w2v']:
for sampling in ['random','balanced']:
for mask in ['raw','bio']:
pred = predict_one(X_test, embedding=embedding, sampling=sampling, masking=masking)
print("\nembedding :",embedding)
print("sampling :",sampling)
print("masking :",masking)
for i,x in enumerate(X_test):
print("sent :",x)
print("pred :",pred[i])
print()
for sampling in ['random','balanced']:
pred = predict_one(X_test, embedding='elmo', sampling=sampling, masking=True)
print("\nembedding :",embedding)
print("sampling :",sampling)
print("masking :",masking)
for i,x in enumerate(X_test):
print("sent :",x)
print("pred :",pred[i])
print()
for sampling in ['random','balanced']:
pred = predict_one(X_test, embedding='elmo', sampling=sampling, masking=True)
print("\nembedding :",embedding)
print("sampling :",sampling)
print("masking :",masking)
for i,x in enumerate(X_test):
print("sent :",x)
print("pred :",pred[i])
print()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--masking", dest='masking',action='store_true', help = "for 'bio' data")
parser.add_argument("--masking", dest='masking',action='store_true', help = "for 'bio' data")
parser.add_argument("--no-masking", dest='masking',action='store_false', help = "for 'raw' data")
parser.add_argument("--sampling", help = "choice of sampling ('random', 'balanced')")
parser.add_argument("--embedding", help = "choice of embeddings to be used ('cv': count_vectorize(self-trained), 'w2v': word2vec_embedding(pre-trained), 'self_w2v':w2v(self-trained), 'elmo':elmo(pre-trained))")
parser.add_argument("--pred_all_models", help = 'yes or no')
if args.pred_all_models == 'yes':
predict_all_models(X_test)
else:
X_test = ['She works at the hospital','He works at the hospital']
if args.pred_all_models == 'yes':
predict_all_models(X_test)
else:
X_test = ['She works at the hospital','He works at the hospital']
pred = predict_one(X_test, embedding=args.embedding, sampling=args.sampling, masking=args.masking)
print("embedding :",args.embedding)
print("sampling :",args.sampling)
print("masking :",args.masking)
for i,x in enumerate(X_test):
print("sent :",x)
print("pred :",pred[i])
print()
pred = predict_one(X_test, embedding=args.embedding, sampling=args.sampling, masking=args.masking)
print("embedding :",args.embedding)
print("sampling :",args.sampling)
print("masking :",args.masking)
for i,x in enumerate(X_test):
print("sent :",x)
print("pred :",pred[i])
print()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment